Modify batch IT to use count instead of hash #26327

Merged
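For context on the change: these ITs previously stringified each KafkaRecord, combined the strings with HashingFn, and compared the result against a precomputed hash per record count; they now count the records and assert the count with PAssert. A minimal, self-contained sketch of the count-and-assert pattern on the direct runner (the Create input and expected value are illustrative, not the IT's Kafka wiring):

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;

public class CountAssertSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Stand-in for the records the IT reads back from Kafka.
    PCollection<String> records = p.apply(Create.of("a", "b", "c"));

    // Count every element and assert on the single global result,
    // instead of hashing record contents and comparing hash strings.
    PCollection<Long> count = records.apply(Count.<String>globally());
    PAssert.thatSingleton(count).isEqualTo(3L);

    p.run().waitUntilFinish();
  }
}
```

One practical upside visible in the diff: a count does not depend on record content or ordering, so the hard-coded expectedHashes map for specific record counts can be dropped.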
@@ -21,11 +21,9 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import com.google.cloud.Timestamp;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -43,13 +41,10 @@
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.common.HashingFn;
import org.apache.beam.sdk.io.common.IOITHelper;
import org.apache.beam.sdk.io.common.IOTestPipelineOptions;
import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.ExperimentalOptions;
@@ -68,6 +63,7 @@
import org.apache.beam.sdk.testutils.metrics.TimeMonitor;
import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
@@ -76,8 +72,8 @@
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.windowing.CalendarWindows;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
@@ -139,8 +135,6 @@ public class KafkaIOIT {

private static final Logger LOG = LoggerFactory.getLogger(KafkaIOIT.class);

private static String expectedHashcode;

private static SyntheticSourceOptions sourceOptions;

private static Options options;
@@ -202,17 +196,23 @@ public void testKafkaIOReadsAndWritesCorrectlyInStreaming() throws IOException {

// Use streaming pipeline to read Kafka records.
readPipeline.getOptions().as(Options.class).setStreaming(true);
readPipeline
.apply("Read from unbounded Kafka", readFromKafka().withTopic(options.getKafkaTopic()))
.apply("Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)))
.apply("Map records to strings", MapElements.via(new MapKafkaRecordsToStrings()))
.apply("Counting element", ParDo.of(new CountingFn(NAMESPACE, READ_ELEMENT_METRIC_NAME)));
PCollection<Long> count =
readPipeline
.apply("Read from unbounded Kafka", readFromKafka().withTopic(options.getKafkaTopic()))
.apply(
"Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)))
.apply("Window", Window.into(CalendarWindows.years(1)))
.apply(
"Counting element",
Combine.globally(Count.<KafkaRecord<byte[], byte[]>>combineFn()).withoutDefaults());

PipelineResult writeResult = writePipeline.run();
PipelineResult.State writeState = writeResult.waitUntilFinish();
// Fail the test if pipeline failed.
assertNotEquals(PipelineResult.State.FAILED, writeState);

PAssert.thatSingleton(count).isEqualTo(sourceOptions.numRecords);

PipelineResult readResult = readPipeline.run();
PipelineResult.State readState =
readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout()));
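On the streaming side, the read is unbounded, so a plain global combine in the global window would never emit a result; the hunk above therefore windows the records (CalendarWindows.years(1)) and counts with Combine.globally(Count.combineFn()).withoutDefaults(), so empty windows produce no default zero. A rough sketch of that pattern on a small timestamped collection (the timestamps, one-minute fixed window, and expected value are illustrative):

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;

public class WindowedCountSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Three timestamped elements standing in for Kafka records.
    PCollection<String> records =
        p.apply(
            Create.timestamped(
                TimestampedValue.of("a", new Instant(0L)),
                TimestampedValue.of("b", new Instant(1L)),
                TimestampedValue.of("c", new Instant(2L))));

    PCollection<Long> count =
        records
            // Non-global windowing so the combine can emit per window.
            .apply(Window.<String>into(FixedWindows.of(Duration.standardMinutes(1))))
            // withoutDefaults(): do not emit a default 0 for empty windows.
            .apply(Combine.globally(Count.<String>combineFn()).withoutDefaults());

    // All three elements fall into one window, so a single count of 3 is produced.
    PAssert.thatSingleton(count).isEqualTo(3L);

    p.run().waitUntilFinish();
  }
}
```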
@@ -221,13 +221,6 @@ public void testKafkaIOReadsAndWritesCorrectlyInStreaming() throws IOException {
tearDownTopic(options.getKafkaTopic());
cancelIfTimeouted(readResult, readState);

long actualRecords = readElementMetric(readResult, NAMESPACE, READ_ELEMENT_METRIC_NAME);
assertTrue(
String.format(
"actual number of records %d smaller than expected: %d.",
actualRecords, sourceOptions.numRecords),
sourceOptions.numRecords <= actualRecords);

if (!options.isWithTestcontainers()) {
Set<NamedTestResult> metrics = readMetrics(writeResult, readResult);
IOITMetrics.publishToInflux(TEST_ID, TIMESTAMP, metrics, settings);
@@ -237,32 +230,25 @@ public void testKafkaIOReadsAndWritesCorrectlyInStreaming() throws IOException {

@Test
public void testKafkaIOReadsAndWritesCorrectlyInBatch() throws IOException {
// Map of hashes of set size collections with 100b records - 10b key, 90b values.
Map<Long, String> expectedHashes =
ImmutableMap.of(
1000L, "4507649971ee7c51abbb446e65a5c660",
100_000_000L, "0f12c27c9a7672e14775594be66cad9a");
expectedHashcode = getHashForRecordCount(sourceOptions.numRecords, expectedHashes);
writePipeline
.apply("Generate records", Read.from(new SyntheticBoundedSource(sourceOptions)))
.apply("Measure write time", ParDo.of(new TimeMonitor<>(NAMESPACE, WRITE_TIME_METRIC_NAME)))
.apply("Write to Kafka", writeToKafka().withTopic(options.getKafkaTopic()));

PCollection<String> hashcode =
PCollection<Long> count =
readPipeline
.apply(
"Read from bounded Kafka",
readFromBoundedKafka().withTopic(options.getKafkaTopic()))
.apply(
"Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)))
.apply("Map records to strings", MapElements.via(new MapKafkaRecordsToStrings()))
.apply("Calculate hashcode", Combine.globally(new HashingFn()).withoutDefaults());

PAssert.thatSingleton(hashcode).isEqualTo(expectedHashcode);
.apply("Counting element", Count.globally());

PipelineResult writeResult = writePipeline.run();
writeResult.waitUntilFinish();

PAssert.thatSingleton(count).isEqualTo(sourceOptions.numRecords);

PipelineResult readResult = readPipeline.run();
PipelineResult.State readState =
readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout()));
@@ -271,8 +257,7 @@ public void testKafkaIOReadsAndWritesCorrectlyInBatch() throws IOException {
tearDownTopic(options.getKafkaTopic());
cancelIfTimeouted(readResult, readState);

// Fail the test if pipeline failed.
assertEquals(PipelineResult.State.DONE, readState);
assertNotEquals(PipelineResult.State.FAILED, readState);

if (!options.isWithTestcontainers()) {
Set<NamedTestResult> metrics = readMetrics(writeResult, readResult);
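The hunk above also relaxes the final-state check from assertEquals(PipelineResult.State.DONE, readState) to assertNotEquals(PipelineResult.State.FAILED, readState), presumably so a run that is cancelled after hitting the read timeout is not reported as a test failure. A rough, self-contained paraphrase of the wait/cancel/assert pattern (the trivial pipeline, the five-minute timeout, and the non-DONE check stand in for the IT's Kafka pipeline, options.getReadTimeout(), and cancelIfTimeouted):

```java
import static org.junit.Assert.assertNotEquals;

import java.io.IOException;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.joda.time.Duration;

public class WaitCancelAssertSketch {
  public static void main(String[] args) throws IOException {
    // Trivial stand-in pipeline; the IT runs its Kafka read pipeline here.
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.create());
    pipeline.apply(Create.of(1, 2, 3));

    PipelineResult result = pipeline.run();
    // Wait with a bounded timeout rather than indefinitely.
    PipelineResult.State state = result.waitUntilFinish(Duration.standardMinutes(5));

    // Rough stand-in for cancelIfTimeouted: stop the job if it did not finish.
    if (!PipelineResult.State.DONE.equals(state)) {
      result.cancel();
    }

    // Only an outright failure fails the test; DONE or CANCELLED both pass.
    assertNotEquals(PipelineResult.State.FAILED, state);
  }
}
```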
@@ -687,9 +672,7 @@ private PipelineResult runWithStopReadingFn(
readFromKafka()
.withTopic(options.getKafkaTopic() + "-" + topicSuffix)
.withCheckStopReadingFn(function))
.apply("Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)))
.apply("Map records to strings", MapElements.via(new MapKafkaRecordsToStrings()))
.apply("Counting element", ParDo.of(new CountingFn(NAMESPACE, READ_ELEMENT_METRIC_NAME)));
.apply("Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME)));

PipelineResult writeResult = writePipeline.run();
writeResult.waitUntilFinish();
@@ -834,19 +817,6 @@ private KafkaIO.Read<byte[], byte[]> readFromKafka() {
.withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest"));
}

private static class CountingFn extends DoFn<String, Void> {

private final Counter elementCounter;

CountingFn(String namespace, String name) {
elementCounter = Metrics.counter(namespace, name);
}

@ProcessElement
public void processElement() {
elementCounter.inc(1L);
}
}
/** Pipeline options specific for this test. */
public interface Options extends IOTestPipelineOptions, StreamingOptions {

@@ -887,25 +857,6 @@ public interface Options extends IOTestPipelineOptions, StreamingOptions {
void setKafkaContainerVersion(String kafkaContainerVersion);
}

private static class MapKafkaRecordsToStrings
extends SimpleFunction<KafkaRecord<byte[], byte[]>, String> {
@Override
public String apply(KafkaRecord<byte[], byte[]> input) {
String key = Arrays.toString(input.getKV().getKey());
String value = Arrays.toString(input.getKV().getValue());
return String.format("%s %s", key, value);
}
}

public static String getHashForRecordCount(long recordCount, Map<Long, String> hashes) {
String hash = hashes.get(recordCount);
if (hash == null) {
throw new UnsupportedOperationException(
String.format("No hash for that record count: %s", recordCount));
}
return hash;
}

private static void setupKafkaContainer() {
kafkaContainer =
new KafkaContainer(