Change UnboundedSourceAsSdfWrapperFn to share the cache across instances. (#33901)

Add a utility class that enables sharing state across all deserialized instances of a DoFn, and use it in UnboundedSourceAsSdfWrapperFn to cache Readers across DoFn instances (a usage sketch follows the change summary below).
scwhittle authored Feb 25, 2025
1 parent 87f0ed3 commit 16f7bb6
Showing 4 changed files with 229 additions and 31 deletions.
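
For illustration, a minimal sketch (not part of this commit) of how a DoFn could use the new MemoizingPerInstantiationSerializableSupplier to share one cache across all of its deserialized copies for the same step on a worker. The DoFn name, element types, lookup helper, and cache settings here are hypothetical.

import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;

class ExpensiveLookupDoFn extends DoFn<String, String> {
  // The supplier is serialized with the DoFn; every instance deserialized from the same
  // serialized DoFn resolves to the same underlying Cache on a given worker JVM.
  private final MemoizingPerInstantiationSerializableSupplier<Cache<String, String>> cacheSupplier =
      new MemoizingPerInstantiationSerializableSupplier<>(
          () ->
              CacheBuilder.newBuilder()
                  .expireAfterWrite(1, TimeUnit.MINUTES)
                  .<String, String>build());

  @ProcessElement
  public void processElement(@Element String key, OutputReceiver<String> out) throws Exception {
    // Look up against the shared cache rather than per-instance state.
    Cache<String, String> shared = cacheSupplier.get();
    out.output(shared.get(key, () -> expensiveLookup(key)));
  }

  private String expensiveLookup(String key) {
    return key.toUpperCase(); // stand-in for real work
  }
}
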
63 changes: 38 additions & 25 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -28,6 +28,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
@@ -51,10 +53,10 @@
import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress;
import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier;
import org.apache.beam.sdk.util.NameUtils;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.PBegin;
@@ -65,10 +67,11 @@
import org.apache.beam.sdk.values.ValueWithRecordId.StripIdsDoFn;
import org.apache.beam.sdk.values.ValueWithRecordId.ValueWithRecordIdCoder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalCause;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.checkerframework.checker.nullness.qual.EnsuresNonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.common.value.qual.ArrayLen;
@@ -481,12 +484,37 @@ static class UnboundedSourceAsSDFWrapperFn<OutputT, CheckpointT extends Checkpoi
private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceAsSDFWrapperFn.class);
private static final int DEFAULT_BUNDLE_FINALIZATION_LIMIT_MINS = 10;
private final Coder<CheckpointT> checkpointCoder;
private @Nullable Cache<Object, UnboundedReader<OutputT>> cachedReaders;
private final MemoizingPerInstantiationSerializableSupplier<
Cache<Object, UnboundedReader<OutputT>>>
readerCacheSupplier;
private static final Executor closeExecutor =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("UnboundedReaderCloses-%d").build());
private @Nullable Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder;

@VisibleForTesting
UnboundedSourceAsSDFWrapperFn(Coder<CheckpointT> checkpointCoder) {
this.checkpointCoder = checkpointCoder;
readerCacheSupplier =
new MemoizingPerInstantiationSerializableSupplier<>(
() ->
CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.MINUTES)
.removalListener(
(RemovalListener<Object, UnboundedReader<OutputT>>)
removalNotification -> {
if (removalNotification.getCause() != RemovalCause.EXPLICIT) {
closeExecutor.execute(
() -> {
try {
checkStateNotNull(removalNotification.getValue()).close();
} catch (IOException e) {
LOG.warn("Failed to close UnboundedReader.", e);
}
});
}
})
.build());
}

@GetInitialRestriction
@@ -498,22 +526,6 @@ public UnboundedSourceRestriction<OutputT, CheckpointT> initialRestriction(
@Setup
public void setUp() throws Exception {
restrictionCoder = restrictionCoder();
cachedReaders =
CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.MINUTES)
.maximumSize(100)
.removalListener(
(RemovalListener<Object, UnboundedReader<OutputT>>)
removalNotification -> {
if (removalNotification.wasEvicted()) {
try {
Preconditions.checkNotNull(removalNotification.getValue()).close();
} catch (IOException e) {
LOG.warn("Failed to close UnboundedReader.", e);
}
}
})
.build();
}

@SplitRestriction
Expand Down Expand Up @@ -556,7 +568,8 @@ public void splitRestriction(
PipelineOptions pipelineOptions) {
Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder =
checkStateNotNull(this.restrictionCoder);
Cache<Object, UnboundedReader<OutputT>> cachedReaders = checkStateNotNull(this.cachedReaders);
Cache<Object, UnboundedReader<OutputT>> cachedReaders =
checkStateNotNull(this.readerCacheSupplier.get());
return new UnboundedSourceAsSDFRestrictionTracker<>(
restriction, pipelineOptions, cachedReaders, restrictionCoder);
}
@@ -840,10 +853,11 @@ private static class UnboundedSourceAsSDFRestrictionTracker<
implements HasProgress {
private final UnboundedSourceRestriction<OutputT, CheckpointT> initialRestriction;
private final PipelineOptions pipelineOptions;
private final Cache<Object, UnboundedReader<OutputT>> cachedReaders;
private final Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder;

private UnboundedSource.@Nullable UnboundedReader<OutputT> currentReader;
private boolean readerHasBeenStarted;
private Cache<Object, UnboundedReader<OutputT>> cachedReaders;
private Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> restrictionCoder;

UnboundedSourceAsSDFRestrictionTracker(
UnboundedSourceRestriction<OutputT, CheckpointT> initialRestriction,
@@ -870,7 +884,8 @@ private void initializeCurrentReader() throws IOException {
checkState(currentReader == null);
Object cacheKey =
createCacheKey(initialRestriction.getSource(), initialRestriction.getCheckpoint());
UnboundedReader<OutputT> cachedReader = cachedReaders.getIfPresent(cacheKey);
// Remove the reader if cached so that it cannot be claimed by multiple DoFn instances.
UnboundedReader<OutputT> cachedReader = cachedReaders.asMap().remove(cacheKey);

if (cachedReader == null) {
this.currentReader =
@@ -879,9 +894,7 @@
.createReader(pipelineOptions, initialRestriction.getCheckpoint());
} else {
// If the reader is from cache, then we know that the reader has been started.
// We also remove this cache entry to avoid eviction.
readerHasBeenStarted = true;
cachedReaders.invalidate(cacheKey);
this.currentReader = cachedReader;
}
}
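
For illustration, a minimal sketch (not part of this commit) of why the tracker above claims a cached reader with cachedReaders.asMap().remove(cacheKey) rather than getIfPresent followed by invalidate: once the cache is shared across DoFn instances, the claim must be atomic so that two instances racing on the same key cannot both obtain the reader. The key and reader types below are placeholders.

import java.util.concurrent.TimeUnit;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;

class ReaderClaimSketch {
  private final Cache<String, AutoCloseable> sharedReaders =
      CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.MINUTES).<String, AutoCloseable>build();

  /** Returns the cached reader for this key, or null; at most one caller wins per entry. */
  AutoCloseable claim(String key) {
    // Atomic removal via the ConcurrentMap view: the entry is taken out of the cache in the
    // same step that it is read, so no other caller can observe and reuse it afterwards.
    return sharedReaders.asMap().remove(key);
  }

  /** Racy variant (tolerable for a per-instance cache, but not for a shared one). */
  AutoCloseable claimRacy(String key) {
    AutoCloseable reader = sharedReaders.getIfPresent(key);
    if (reader != null) {
      // Another thread may have already obtained the same reader between these two calls.
      sharedReaders.invalidate(key);
    }
    return reader;
  }
}
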
60 changes: 60 additions & 0 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplier.java
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.util;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
* A supplier that memoizes within an instantiation across serialization/deserialization.
*
 * <p>Specifically, the wrapped supplier will be called once and the result memoized per group
 * consisting of an instance and all instances deserialized from its serialized state.
 *
 * <p>A particular use for this is within a DoFn class to maintain shared state across all instances
 * of the DoFn that correspond to the same step in the graph, while keeping it separate from other
 * steps in the graph that use the same DoFn class. This differs from a static variable, which would
 * be shared across all instances of the DoFn class, and from a non-static variable, which is per
 * instance.
*/
public class MemoizingPerInstantiationSerializableSupplier<T> implements SerializableSupplier<T> {
private static final AtomicInteger idGenerator = new AtomicInteger();
private final int id;

private static final ConcurrentHashMap<Integer, Object> staticCache = new ConcurrentHashMap<>();
private final SerializableSupplier<@NonNull T> supplier;
private transient volatile @MonotonicNonNull T value;

public MemoizingPerInstantiationSerializableSupplier(SerializableSupplier<@NonNull T> supplier) {
id = idGenerator.incrementAndGet();
this.supplier = supplier;
}

@Override
@SuppressWarnings("unchecked")
public T get() {
@Nullable T result = value;
if (result != null) {
return result;
}
@Nullable T mapValue = (T) staticCache.computeIfAbsent(id, ignored -> supplier.get());
return value = Preconditions.checkStateNotNull(mapValue);
}
}
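
For illustration, a hypothetical DoFn (not part of this commit) contrasting the three sharing scopes described in the javadoc above: a static field shared by every use of the DoFn class on a worker, a MemoizingPerInstantiationSerializableSupplier shared only by copies deserialized from the same serialized instance (i.e. the same step), and a transient per-instance field. All names here are placeholders.

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.util.MemoizingPerInstantiationSerializableSupplier;

class ScopesDoFn extends DoFn<String, String> {
  // Static: one map per worker JVM, shared by every step that uses this DoFn class.
  private static final ConcurrentMap<String, String> perClass = new ConcurrentHashMap<>();

  // Per-instantiation: one map per logical step; all copies deserialized from the same
  // serialized DoFn resolve to the same map, while another step using this class gets its own.
  private final MemoizingPerInstantiationSerializableSupplier<ConcurrentMap<String, String>>
      perStep = new MemoizingPerInstantiationSerializableSupplier<>(ConcurrentHashMap::new);

  // Per-instance (transient, rebuilt after deserialization): one map per DoFn copy.
  private transient ConcurrentMap<String, String> perInstance;

  @Setup
  public void setup() {
    perInstance = new ConcurrentHashMap<>();
  }

  @ProcessElement
  public void processElement(@Element String element) {
    perClass.putIfAbsent(element, element);
    perStep.get().putIfAbsent(element, element);
    perInstance.putIfAbsent(element, element);
  }
}
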
7 changes: 1 addition & 6 deletions sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
@@ -179,12 +179,7 @@ public void testUnboundedSdfWrapperCacheStartedReaders() {
// read is default.
ExperimentalOptions.addExperiment(
pipeline.getOptions().as(ExperimentalOptions.class), "use_sdf_read");
// Force the pipeline to run with one thread to ensure the reader will be reused on one DoFn
// instance.
// We are not able to use DirectOptions because of circular dependency.
pipeline
.runWithAdditionalOptionArgs(ImmutableList.of("--targetParallelism=1"))
.waitUntilFinish();
pipeline.run().waitUntilFinish();
}

@Test
130 changes: 130 additions & 0 deletions sdks/java/core/src/test/java/org/apache/beam/sdk/util/MemoizingPerInstantiationSerializableSupplierTest.java
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.util;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class MemoizingPerInstantiationSerializableSupplierTest {

@SuppressWarnings("unchecked")
@Test
public void testSharedAcrossDeserialize() throws Exception {
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance);

AtomicInteger i = instance.get();
i.set(10);
assertSame(i, instance.get());

byte[] serialized = SerializableUtils.serializeToByteArray(instance);
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> deserialized1 =
(MemoizingPerInstantiationSerializableSupplier<AtomicInteger>)
SerializableUtils.deserializeFromByteArray(serialized, "instance");
assertSame(i, deserialized1.get());

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> deserialized2 =
(MemoizingPerInstantiationSerializableSupplier<AtomicInteger>)
SerializableUtils.deserializeFromByteArray(serialized, "instance");
assertSame(i, deserialized2.get());
assertEquals(10, i.get());
}

@Test
public void testDifferentInstancesSeparate() throws Exception {
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance);
AtomicInteger i = instance.get();
i.set(10);
assertSame(i, instance.get());

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance2 =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance2);
AtomicInteger j = instance2.get();
j.set(20);
assertSame(j, instance2.get());
assertNotSame(j, i);

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance1clone =
SerializableUtils.clone(instance);
assertSame(instance1clone.get(), i);
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance2clone =
SerializableUtils.clone(instance2);
assertSame(instance2clone.get(), j);
}

@SuppressWarnings("unchecked")
@Test
public void testDifferentInstancesSeparateNoGetBeforeSerialization() throws Exception {
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance);

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance2 =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance2);

byte[] serialized = SerializableUtils.serializeToByteArray(instance);
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> deserialized1 =
(MemoizingPerInstantiationSerializableSupplier<AtomicInteger>)
SerializableUtils.deserializeFromByteArray(serialized, "instance");
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> deserialized2 =
(MemoizingPerInstantiationSerializableSupplier<AtomicInteger>)
SerializableUtils.deserializeFromByteArray(serialized, "instance");
assertSame(deserialized1.get(), deserialized2.get());

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance2clone =
SerializableUtils.clone(instance2);
assertNotSame(instance2clone.get(), deserialized1.get());
}

@Test
public void testDifferentTypes() throws Exception {
MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance =
new MemoizingPerInstantiationSerializableSupplier<>(AtomicInteger::new);
SerializableUtils.ensureSerializable(instance);
AtomicInteger i = instance.get();
i.set(10);
assertSame(i, instance.get());

MemoizingPerInstantiationSerializableSupplier<ConcurrentHashMap<Integer, Integer>> instance2 =
new MemoizingPerInstantiationSerializableSupplier<>(ConcurrentHashMap::new);
SerializableUtils.ensureSerializable(instance2);
ConcurrentHashMap<Integer, Integer> j = instance2.get();
j.put(1, 100);
assertSame(j, instance2.get());

MemoizingPerInstantiationSerializableSupplier<AtomicInteger> instance1clone =
SerializableUtils.clone(instance);
assertSame(instance1clone.get(), i);
MemoizingPerInstantiationSerializableSupplier<ConcurrentHashMap<Integer, Integer>>
instance2clone = SerializableUtils.clone(instance2);
assertSame(instance2clone.get(), j);
}
}
