Skip to content

Commit

Permalink
Change StateBackedIterable to implement ElementByteSizeObservableIterable, avoiding iteration to estimate observed bytes. (#29517)
Browse files Browse the repository at this point in the history

* Change StateBackedIterable to implement ElementByteSizeObservableIterable reducing byte estimation costs.
  • Loading branch information
scwhittle authored Nov 23, 2023
1 parent 183141d commit 6ce1047
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,17 @@
import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
import org.apache.beam.sdk.util.BufferedElementCountingOutputStream;
import org.apache.beam.sdk.util.VarInt;
import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A {@link BeamFnStateClient state} backed iterable which allows for fetching elements over the
Expand All @@ -62,12 +67,17 @@
@SuppressWarnings({
"rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
})
public class StateBackedIterable<T> implements Iterable<T>, Serializable {
public class StateBackedIterable<T>
extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>>
implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(StateBackedIterable.class);

@VisibleForTesting final StateRequest request;
@VisibleForTesting final List<T> prefix;
private final transient PrefetchableIterable<T> suffix;

private final org.apache.beam.sdk.coders.Coder<T> elemCoder;

public StateBackedIterable(
Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
Expand All @@ -81,11 +91,82 @@ public StateBackedIterable(
this.suffix =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, stateKey), beamFnStateClient, request, elemCoder);
this.elemCoder = elemCoder;
}

/**
 * An {@link ElementByteSizeObservableIterator} that wraps another iterator and reports the byte
 * size of each element it returns.
 *
 * <p>Every value returned by {@link #next()} is registered with {@code elementCoder} against an
 * internal proxy observer. Whenever the proxy reports an element size, it is forwarded to this
 * iterator's own observers through {@code notifyValueReturned}.
 *
 * <p>If the proxy is lazy after registration, advancing it is deferred until the next call to
 * {@link #hasNext()} or {@link #next()} so that bytes observed while the caller consumes the value
 * have a chance to be captured.
 */
@SuppressWarnings("nullness")
private static class WrappedObservingIterator<T> extends ElementByteSizeObservableIterator<T> {
  private final Iterator<T> wrappedIterator;
  private final org.apache.beam.sdk.coders.Coder<T> elementCoder;

  // Logically final and non-null but initialized after construction by factory method for
  // initialization ordering.
  private ElementByteSizeObserver observerProxy = null;

  // True while a lazily observed element is still waiting for observerProxy.advance().
  private boolean observerNeedsAdvance = false;
  // Guards the warning in next() so that it is logged at most once per iterator.
  private boolean exceptionLogged = false;

  /**
   * Creates a fully initialized iterator.
   *
   * <p>The proxy observer has to reference the iterator it reports to, so it is attached here
   * after the iterator itself has been constructed.
   *
   * @param iterator the iterator supplying the elements
   * @param elementCoder the coder used to register each returned element for byte size observation
   * @return an iterator over the same elements that notifies its observers of element byte sizes
   */
  static <T> WrappedObservingIterator<T> create(
      Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) {
    WrappedObservingIterator<T> result = new WrappedObservingIterator<>(iterator, elementCoder);
    result.observerProxy =
        new ElementByteSizeObserver() {
          @Override
          protected void reportElementSize(long elementByteSize) {
            result.notifyValueReturned(elementByteSize);
          }
        };
    return result;
  }

  private WrappedObservingIterator(
      Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) {
    this.wrappedIterator = iterator;
    this.elementCoder = elementCoder;
  }

  /** Advances the proxy observer if a previously returned element is still pending. */
  private void advanceObserverIfPending() {
    if (observerNeedsAdvance) {
      observerProxy.advance();
      observerNeedsAdvance = false;
    }
  }

  @Override
  public boolean hasNext() {
    advanceObserverIfPending();
    return wrappedIterator.hasNext();
  }

  @Override
  public T next() {
    // Flush the previous element first. Without this, a caller that invokes next() repeatedly
    // without an intervening hasNext() would leave the previous element's bytes pending in the
    // proxy, where they would be combined with the bytes of the element registered below.
    advanceObserverIfPending();
    T value = wrappedIterator.next();
    try {
      elementCoder.registerByteSizeObserver(value, observerProxy);
      if (observerProxy.getIsLazy()) {
        // The observer will only be notified of bytes as the result
        // is used. We defer advancing the observer until hasNext in an
        // attempt to capture those bytes.
        observerNeedsAdvance = true;
      } else {
        observerNeedsAdvance = false;
        observerProxy.advance();
      }
    } catch (Exception e) {
      // Size observation is best effort: the element is still returned to the caller.
      if (!exceptionLogged) {
        LOG.warn("Lazily observed byte size will be under reported due to exception", e);
        exceptionLogged = true;
      }
    }
    return value;
  }

  @Override
  public void remove() {
    // Intentionally defers to the superclass rather than to wrappedIterator.
    super.remove();
  }
}

@Override
public Iterator<T> iterator() {
return PrefetchableIterators.concat(prefix.iterator(), suffix.iterator());
protected ElementByteSizeObservableIterator<T> createIterator() {
return WrappedObservingIterator.create(
PrefetchableIterators.concat(prefix.iterator(), suffix.iterator()), elemCoder);
}

protected Object writeReplace() throws ObjectStreamException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -36,11 +37,13 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -213,6 +216,61 @@ public void testUsingInterleavedReiteration() throws Exception {
}
}
}

/** Test observer that accumulates every reported element size into {@link #total}. */
private static class TestByteObserver extends ElementByteSizeObserver {
  /** Running sum of all element sizes reported so far. */
  public long total = 0L;

  @Override
  protected void reportElementSize(long elementByteSize) {
    total = total + elementByteSize;
  }
}

/**
 * Checks that registering a byte size observer on a {@link StateBackedIterable} is reported as
 * cheap and leaves the observer lazy, and that iterating the iterable afterwards reports the
 * expected number of bytes to the observer.
 */
@Test
public void testByteObservingStateBackedIterable() throws Exception {
  FakeBeamFnStateClient stateClient =
      new FakeBeamFnStateClient(
          StringUtf8Coder.of(),
          ImmutableMap.of(
              key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"),
              key("emptySuffix"), asList()));

  StateBackedIterable<String> stateBackedIterable =
      new StateBackedIterable<>(
          Caches.noop(), stateClient, "instruction", key(suffixKey), StringUtf8Coder.of(), prefix);
  StateBackedIterable.Coder<String> iterableCoder =
      new StateBackedIterable.Coder<>(
          () -> Caches.noop(), stateClient, () -> "instructionId", StringUtf8Coder.of());

  // Registration must not require iterating the elements up front.
  assertTrue(iterableCoder.isRegisterByteSizeObserverCheap(stateBackedIterable));
  TestByteObserver byteObserver = new TestByteObserver();
  iterableCoder.registerByteSizeObserver(stateBackedIterable, byteObserver);
  assertTrue(byteObserver.getIsLazy());

  // Iterating the iterable is what drives the lazily observed bytes.
  long expectedElementBytes =
      Streams.stream(stateBackedIterable).mapToLong(this::encodedSizeWithHasNextFlag).sum();
  byteObserver.advance();
  // 5 comes from size and hasNext (see IterableLikeCoder)
  long expectedTotal = expectedElementBytes + 5;
  assertEquals(expectedTotal, byteObserver.total);
}

/** Returns the encoded size of {@code element} plus one byte for its hasNext flag. */
private long encodedSizeWithHasNextFlag(String element) {
  try {
    // 1 comes from hasNext = true flag (see IterableLikeCoder)
    return 1 + StringUtf8Coder.of().getEncodedElementByteSize(element);
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}
}

@RunWith(JUnit4.class)
Expand Down

0 comments on commit 6ce1047

Please sign in to comment.