Add redistribute option for Kafka Read
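
Adds withRedistribute(), withAllowDuplicates(), and withRedistributeNumKeys() to
KafkaIO.Read and ReadSourceDescriptors, plumbs the new options through the external
transform builder, and applies a Redistribute transform to the read output when enabled.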
Naireen committed Jul 8, 2024
1 parent de4645d commit 9cbdda1
Showing 9 changed files with 302 additions and 19 deletions.
@@ -75,6 +75,7 @@
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Redistribute;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -600,6 +601,9 @@ public static <K, V> Read<K, V> read() {
.setDynamicRead(false)
.setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
.setConsumerPollingTimeout(2L)
.setRedistributed(false)
.setAllowDuplicates(false)
.setRedistributeNumKeys(0)
.build();
}

@@ -698,6 +702,15 @@ public abstract static class Read<K, V>
@Pure
public abstract boolean isDynamicRead();

@Pure
public abstract boolean isRedistributed();

@Pure
public abstract boolean isAllowDuplicates();

@Pure
public abstract int getRedistributeNumKeys();

@Pure
public abstract @Nullable Duration getWatchTopicPartitionDuration();

@@ -757,6 +770,12 @@ abstract Builder<K, V> setConsumerFactoryFn(

abstract Builder<K, V> setWatchTopicPartitionDuration(Duration duration);

abstract Builder<K, V> setRedistributed(boolean withRedistribute);

abstract Builder<K, V> setAllowDuplicates(boolean allowDuplicates);

abstract Builder<K, V> setRedistributeNumKeys(int redistributeNumKeys);

abstract Builder<K, V> setTimestampPolicyFactory(
TimestampPolicyFactory<K, V> timestampPolicyFactory);

@@ -852,6 +871,22 @@ static <K, V> void setupExternalBuilder(
} else {
builder.setConsumerPollingTimeout(2L);
}

if (config.redistribute != null) {
builder.setRedistributed(config.redistribute);
if (config.redistributeNumKeys != null) {
builder.setRedistributeNumKeys((int) config.redistributeNumKeys);
}
if (config.allowDuplicates != null) {
builder.setAllowDuplicates(config.allowDuplicates);
}

} else {
builder.setRedistributed(false);
builder.setRedistributeNumKeys(0);
builder.setAllowDuplicates(false);
}
System.out.println("xxx builder service" + builder.toString());
}

private static <T> Coder<T> resolveCoder(Class<Deserializer<T>> deserializer) {
@@ -916,6 +951,9 @@ public static class Configuration {
private Boolean commitOffsetInFinalize;
private Long consumerPollingTimeout;
private String timestampPolicy;
private Integer redistributeNumKeys;
private Boolean redistribute;
private Boolean allowDuplicates;

public void setConsumerConfig(Map<String, String> consumerConfig) {
this.consumerConfig = consumerConfig;
@@ -960,6 +998,18 @@ public void setTimestampPolicy(String timestampPolicy) {
public void setConsumerPollingTimeout(Long consumerPollingTimeout) {
this.consumerPollingTimeout = consumerPollingTimeout;
}

public void setRedistributeNumKeys(Integer redistributeNumKeys) {
this.redistributeNumKeys = redistributeNumKeys;
}

public void setRedistribute(Boolean redistribute) {
this.redistribute = redistribute;
}

public void setAllowDuplicates(Boolean allowDuplicates) {
this.allowDuplicates = allowDuplicates;
}
}
}

@@ -1007,6 +1057,30 @@ public Read<K, V> withTopicPartitions(List<TopicPartition> topicPartitions) {
return toBuilder().setTopicPartitions(ImmutableList.copyOf(topicPartitions)).build();
}

/**
* Enables a {@link Redistribute} transform, which hints to the runner to redistribute the work
* evenly.
*/
public Read<K, V> withRedistribute() {
if (getRedistributeNumKeys() == 0 && isRedistributed()) {
LOG.warn("This will create a key per record, which is sub-optimal for most use cases.");
}
return toBuilder().setRedistributed(true).build();
}

public Read<K, V> withAllowDuplicates(Boolean allowDuplicates) {
if (!isRedistributed()) {
LOG.warn("Setting this value without setting withRedistribute() will have no effect.");
}
return toBuilder().setAllowDuplicates(allowDuplicates).build();
}

public Read<K, V> withRedistributeNumKeys(int redistributeNumKeys) {
checkState(
isRedistributed(),
"withRedistributeNumKeys is ignored if withRedistribute() is not enabled on the transform.");
return toBuilder().setRedistributeNumKeys(redistributeNumKeys).build();
}
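
// Illustrative usage sketch, not part of this commit: the three options above are
// expected to compose with the existing Read configuration. The bootstrap servers,
// topic, and deserializers below are placeholder assumptions.
//
// PCollection<KafkaRecord<Long, String>> records =
//     pipeline.apply(
//         KafkaIO.<Long, String>read()
//             .withBootstrapServers("broker:9092")
//             .withTopic("my-topic")
//             .withKeyDeserializer(LongDeserializer.class)
//             .withValueDeserializer(StringDeserializer.class)
//             .withRedistribute()
//             .withRedistributeNumKeys(32)
//             .withAllowDuplicates(true));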

/**
* Internally sets a {@link java.util.regex.Pattern} of topics to read from. All the partitions
* from each of the matching topics are read.
@@ -1618,6 +1692,25 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
.withMaxNumRecords(kafkaRead.getMaxNumRecords());
}

if (kafkaRead.isRedistributed()) {
// Committing offsets on finalize is incompatible with redistribute, so fail fast here.
checkArgument(
!kafkaRead.isCommitOffsetsInFinalizeEnabled(),
"commitOffsetsInFinalize() can't be enabled with withRedistribute()");
PCollection<KafkaRecord<K, V>> output = input.getPipeline().apply(transform);
if (kafkaRead.getRedistributeNumKeys() == 0) {
return output.apply(
"Insert Redistribute",
Redistribute.<KafkaRecord<K, V>>arbitrarily()
.withAllowDuplicates(kafkaRead.isAllowDuplicates()));
} else {
return output.apply(
"Insert Redistribute with Shards",
Redistribute.<KafkaRecord<K, V>>arbitrarily()
.withAllowDuplicates(kafkaRead.isAllowDuplicates())
.withNumBuckets(kafkaRead.getRedistributeNumKeys()));
}
}
return input.getPipeline().apply(transform);
}
}
@@ -1637,6 +1730,8 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
.withKeyDeserializerProvider(kafkaRead.getKeyDeserializerProvider())
.withValueDeserializerProvider(kafkaRead.getValueDeserializerProvider())
.withManualWatermarkEstimator()
.withTimestampPolicyFactory(kafkaRead.getTimestampPolicyFactory())
.withCheckStopReadingFn(kafkaRead.getCheckStopReadingFn())
.withConsumerPollingTimeout(kafkaRead.getConsumerPollingTimeout());
@@ -1650,6 +1745,15 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
readTransform =
readTransform.withBadRecordErrorHandler(kafkaRead.getBadRecordErrorHandler());
}
if (kafkaRead.isRedistributed()) {
readTransform = readTransform.withRedistribute();
}
if (kafkaRead.isAllowDuplicates()) {
readTransform = readTransform.withAllowDuplicates();
}
if (kafkaRead.getRedistributeNumKeys() > 0) {
readTransform = readTransform.withRedistributeNumKeys(kafkaRead.getRedistributeNumKeys());
}
PCollection<KafkaSourceDescriptor> output;
if (kafkaRead.isDynamicRead()) {
Set<String> topics = new HashSet<>();
@@ -1679,6 +1783,22 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
.apply(Impulse.create())
.apply(ParDo.of(new GenerateKafkaSourceDescriptor(kafkaRead)));
}
if (kafkaRead.isRedistributed()) {
PCollection<KafkaRecord<K, V>> pcol =
output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
if (kafkaRead.getRedistributeNumKeys() == 0) {
return pcol.apply(
"Insert Redistribute",
Redistribute.<KafkaRecord<K, V>>arbitrarily()
.withAllowDuplicates(kafkaRead.isAllowDuplicates()));
} else {
return pcol.apply(
"Insert Redistribute with Shards",
Redistribute.<KafkaRecord<K, V>>arbitrarily()
.withAllowDuplicates(kafkaRead.isAllowDuplicates())
.withNumBuckets(kafkaRead.getRedistributeNumKeys()));
}
}
return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
}
}
@@ -2070,6 +2190,15 @@ public abstract static class ReadSourceDescriptors<K, V>
@Pure
abstract boolean isCommitOffsetEnabled();

@Pure
abstract boolean isRedistribute();

@Pure
abstract boolean isAllowDuplicates();

@Pure
abstract int getRedistributeNumKeys();

@Pure
abstract @Nullable TimestampPolicyFactory<K, V> getTimestampPolicyFactory();

@@ -2136,6 +2265,12 @@ abstract ReadSourceDescriptors.Builder<K, V> setBadRecordErrorHandler(

abstract ReadSourceDescriptors.Builder<K, V> setBounded(boolean bounded);

abstract ReadSourceDescriptors.Builder<K, V> setRedistribute(boolean withRedistribute);

abstract ReadSourceDescriptors.Builder<K, V> setAllowDuplicates(boolean allowDuplicates);

abstract ReadSourceDescriptors.Builder<K, V> setRedistributeNumKeys(int redistributeNumKeys);

abstract ReadSourceDescriptors<K, V> build();
}

@@ -2148,6 +2283,9 @@ public static <K, V> ReadSourceDescriptors<K, V> read() {
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
.setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>())
.setConsumerPollingTimeout(2L)
.setRedistribute(false)
.setAllowDuplicates(false)
.setRedistributeNumKeys(0)
.build()
.withProcessingTime()
.withMonotonicallyIncreasingWatermarkEstimator();
@@ -2307,6 +2445,19 @@ public ReadSourceDescriptors<K, V> withProcessingTime() {
ReadSourceDescriptors.ExtractOutputTimestampFns.useProcessingTime());
}

/** Enables the Redistribute transform on the read output. */
public ReadSourceDescriptors<K, V> withRedistribute() {
return toBuilder().setRedistribute(true).build();
}

public ReadSourceDescriptors<K, V> withAllowDuplicates() {
return toBuilder().setAllowDuplicates(true).build();
}

public ReadSourceDescriptors<K, V> withRedistributeNumKeys(int redistributeNumKeys) {
return toBuilder().setRedistributeNumKeys(redistributeNumKeys).build();
}
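
// Hypothetical sketch of chaining these options on ReadSourceDescriptors; the
// factory and deserializer calls are assumed from the public KafkaIO API.
//
// KafkaIO.ReadSourceDescriptors<Long, String> descriptors =
//     KafkaIO.<Long, String>readSourceDescriptors()
//         .withKeyDeserializer(LongDeserializer.class)
//         .withValueDeserializer(StringDeserializer.class)
//         .withRedistribute()
//         .withAllowDuplicates()
//         .withRedistributeNumKeys(16);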

/** Use the creation time of {@link KafkaRecord} as the output timestamp. */
public ReadSourceDescriptors<K, V> withCreateTime() {
return withExtractOutputTimestampFn(
@@ -2497,6 +2648,12 @@ public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor>
}
}

if (isRedistribute() && getRedistributeNumKeys() == 0) {
LOG.warn("This will create a key per record, which is sub-optimal for most use cases.");
}

if (getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) == null) {
LOG.warn(
"The bootstrapServers is not set. It must be populated through the KafkaSourceDescriptor during runtime otherwise the pipeline will fail.");
Expand Down Expand Up @@ -2527,7 +2684,7 @@ public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor>
.getSchemaRegistry()
.getSchemaCoder(KafkaSourceDescriptor.class),
recordCoder));
if (isCommitOffsetEnabled() && !configuredKafkaCommit()) {
if (isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute()) {
outputWithDescriptor =
outputWithDescriptor
.apply(Reshuffle.viaRandomKey())
@@ -2538,6 +2695,7 @@ public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor>
.getSchemaRegistry()
.getSchemaCoder(KafkaSourceDescriptor.class),
recordCoder));

PCollection<Void> unused = outputWithDescriptor.apply(new KafkaCommitOffset<K, V>(this));
unused.setCoder(VoidCoder.of());
}
@@ -119,6 +119,24 @@ Object getDefaultValue() {
return Long.valueOf(2);
}
},
REDISTRIBUTE_NUM_KEYS {
@Override
Object getDefaultValue() {
return Integer.valueOf(0);
}
},
REDISTRIBUTED {
@Override
Object getDefaultValue() {
return false;
}
},
ALLOW_DUPLICATES {
@Override
Object getDefaultValue() {
return false;
}
},
;

private final @NonNull ImmutableSet<KafkaIOReadImplementation> supportedImplementations;
@@ -108,7 +108,10 @@ public void testConstructKafkaRead() throws Exception {
Field.of("start_read_time", FieldType.INT64),
Field.of("commit_offset_in_finalize", FieldType.BOOLEAN),
Field.of("timestamp_policy", FieldType.STRING),
Field.of("consumer_polling_timeout", FieldType.INT64)))
Field.of("consumer_polling_timeout", FieldType.INT64),
Field.of("redistribute_num_keys", FieldType.INT32),
Field.of("redistribute", FieldType.BOOLEAN),
Field.of("allow_duplicates", FieldType.BOOLEAN)))
.withFieldValue("topics", topics)
.withFieldValue("consumer_config", consumerConfig)
.withFieldValue("key_deserializer", keyDeserializer)
@@ -117,6 +120,9 @@
.withFieldValue("commit_offset_in_finalize", false)
.withFieldValue("timestamp_policy", "ProcessingTime")
.withFieldValue("consumer_polling_timeout", 5L)
.withFieldValue("redistribute_num_keys", 0)
.withFieldValue("redistribute", false)
.withFieldValue("allow_duplicates", false)
.build());

RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -139,6 +145,7 @@
expansionService.expand(request, observer);
ExpansionApi.ExpansionResponse result = observer.result;
RunnerApi.PTransform transform = result.getTransform();
System.out.println("xxx : " + result.toString());
assertThat(
transform.getSubtransformsList(),
Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-Read.*")));
@@ -237,14 +244,20 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
Field.of("value_deserializer", FieldType.STRING),
Field.of("start_read_time", FieldType.INT64),
Field.of("commit_offset_in_finalize", FieldType.BOOLEAN),
Field.of("timestamp_policy", FieldType.STRING)))
Field.of("timestamp_policy", FieldType.STRING),
Field.of("redistribute_num_keys", FieldType.INT32),
Field.of("redistribute", FieldType.BOOLEAN),
Field.of("allow_duplicates", FieldType.BOOLEAN)))
.withFieldValue("topics", topics)
.withFieldValue("consumer_config", consumerConfig)
.withFieldValue("key_deserializer", keyDeserializer)
.withFieldValue("value_deserializer", valueDeserializer)
.withFieldValue("start_read_time", startReadTime)
.withFieldValue("commit_offset_in_finalize", false)
.withFieldValue("timestamp_policy", "ProcessingTime")
.withFieldValue("redistribute_num_keys", 0)
.withFieldValue("redistribute", false)
.withFieldValue("allow_duplicates", false)
.build());

RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -103,7 +103,9 @@ public void testPrimitiveKafkaIOReadPropertiesDefaultValueExistence() {

private void testReadTransformCreationWithImplementationBoundProperties(
Function<KafkaIO.Read<Integer, Long>, KafkaIO.Read<Integer, Long>> kafkaReadDecorator) {
p.apply(kafkaReadDecorator.apply(mkKafkaReadTransform(1000, null, new ValueAsTimestampFn())));
p.apply(
kafkaReadDecorator.apply(
mkKafkaReadTransform(1000, null, new ValueAsTimestampFn(), false, 0)));
p.run();
}
