From 156102b7975515a75903bd41dd4534bb791f2e36 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Fri, 29 Sep 2023 17:31:21 -0700 Subject: [PATCH 1/2] Refactor StateFetcher --- .../runners/dataflow/worker/StateFetcher.java | 291 ------------------ .../worker/StreamingDataflowWorker.java | 11 +- .../worker/StreamingModeExecutionContext.java | 114 ++++--- .../worker/StreamingSideInputFetcher.java | 8 +- .../worker/streaming/sideinput/SideInput.java | 43 +++ .../streaming/sideinput/SideInputCache.java | 93 ++++++ .../streaming/sideinput/SideInputState.java | 25 ++ .../sideinput/SideInputStateFetcher.java | 215 +++++++++++++ .../worker/StreamingDataflowWorkerTest.java | 6 +- .../StreamingModeExecutionContextTest.java | 9 +- .../StreamingSideInputDoFnRunnerTest.java | 2 +- .../worker/StreamingSideInputFetcherTest.java | 2 +- .../sideinput/SideInputStateFetcherTest.java} | 161 ++++++---- 13 files changed, 562 insertions(+), 418 deletions(-) delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java rename runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/{StateFetcherTest.java => streaming/sideinput/SideInputStateFetcherTest.java} (70%) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java deleted file mode 100644 index 0cbcd2e83012..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StateFetcher.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - -import java.io.Closeable; -import java.util.Collections; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.Callable; -import java.util.concurrent.TimeUnit; -import org.apache.beam.runners.core.InMemoryMultimapSideInputView; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.Materializations; -import org.apache.beam.sdk.transforms.Materializations.IterableView; -import org.apache.beam.sdk.transforms.Materializations.MultimapView; -import org.apache.beam.sdk.transforms.ViewFn; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.ByteStringOutputStream; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** Class responsible for fetching state from the windmill server. */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -class StateFetcher { - private static final Set SUPPORTED_MATERIALIZATIONS = - ImmutableSet.of( - Materializations.ITERABLE_MATERIALIZATION_URN, - Materializations.MULTIMAP_MATERIALIZATION_URN); - - private static final Logger LOG = LoggerFactory.getLogger(StateFetcher.class); - - private Cache sideInputCache; - private MetricTrackingWindmillServerStub server; - private long bytesRead = 0L; - - public StateFetcher(MetricTrackingWindmillServerStub server) { - this( - server, - CacheBuilder.newBuilder() - .maximumWeight(100000000 /* 100 MB */) - .expireAfterWrite(1, TimeUnit.MINUTES) - .weigher((Weigher) (id, entry) -> entry.size()) - .build()); - } - - public StateFetcher( - MetricTrackingWindmillServerStub server, - Cache sideInputCache) { - this.server = server; - this.sideInputCache = sideInputCache; - } - - /** Returns a view of the underlying cache that keeps track of bytes read separately. */ - public StateFetcher byteTrackingView() { - return new StateFetcher(server, sideInputCache); - } - - public long getBytesRead() { - return bytesRead; - } - - /** Indicates the caller's knowledge of whether a particular side input has been computed. */ - public enum SideInputState { - CACHED_IN_WORKITEM, - KNOWN_READY, - UNKNOWN; - } - - /** - * Fetch the given side input, storing it in a process-level cache. - * - *

If state is KNOWN_READY, attempt to fetch the data regardless of whether a not-ready entry - * was cached. - * - *

Returns {@literal null} if the side input was not ready, {@literal Optional.absent()} if the - * side input was null, and {@literal Optional.present(...)} if the side input was non-null. - */ - public @Nullable Optional fetchSideInput( - final PCollectionView view, - final SideWindowT sideWindow, - final String stateFamily, - SideInputState state, - final Supplier scopedReadStateSupplier) { - final SideInputId id = new SideInputId(view.getTagInternal(), sideWindow); - - Callable fetchCallable = - () -> { - @SuppressWarnings("unchecked") - WindowingStrategy sideWindowStrategy = - (WindowingStrategy) view.getWindowingStrategyInternal(); - - Coder windowCoder = sideWindowStrategy.getWindowFn().windowCoder(); - - ByteStringOutputStream windowStream = new ByteStringOutputStream(); - windowCoder.encode(sideWindow, windowStream, Coder.Context.OUTER); - - @SuppressWarnings("unchecked") - Windmill.GlobalDataRequest request = - Windmill.GlobalDataRequest.newBuilder() - .setDataId( - Windmill.GlobalDataId.newBuilder() - .setTag(view.getTagInternal().getId()) - .setVersion(windowStream.toByteString()) - .build()) - .setStateFamily(stateFamily) - .setExistenceWatermarkDeadline( - WindmillTimeUtils.harnessToWindmillTimestamp( - sideWindowStrategy - .getTrigger() - .getWatermarkThatGuaranteesFiring(sideWindow))) - .build(); - - Windmill.GlobalData data; - try (Closeable scope = scopedReadStateSupplier.get()) { - data = server.getSideInputData(request); - } - - bytesRead += data.getSerializedSize(); - - checkState( - SUPPORTED_MATERIALIZATIONS.contains(view.getViewFn().getMaterialization().getUrn()), - "Only materializations of type %s supported, received %s", - SUPPORTED_MATERIALIZATIONS, - view.getViewFn().getMaterialization().getUrn()); - - Iterable rawData; - if (data.getIsReady()) { - if (data.getData().size() > 0) { - rawData = - IterableCoder.of(view.getCoderInternal()) - .decode(data.getData().newInput(), Coder.Context.OUTER); - } else { - rawData = Collections.emptyList(); - } - - switch (view.getViewFn().getMaterialization().getUrn()) { - case Materializations.ITERABLE_MATERIALIZATION_URN: - { - ViewFn viewFn = (ViewFn) view.getViewFn(); - return SideInputCacheEntry.ready( - viewFn.apply(() -> rawData), data.getData().size()); - } - case Materializations.MULTIMAP_MATERIALIZATION_URN: - { - ViewFn viewFn = (ViewFn) view.getViewFn(); - Coder keyCoder = ((KvCoder) view.getCoderInternal()).getKeyCoder(); - return SideInputCacheEntry.ready( - viewFn.apply( - InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)), - data.getData().size()); - } - default: - throw new IllegalStateException( - String.format( - "Unknown side input materialization format requested '%s'", - view.getViewFn().getMaterialization().getUrn())); - } - } else { - return SideInputCacheEntry.notReady(); - } - }; - - try { - if (state == SideInputState.KNOWN_READY) { - SideInputCacheEntry entry = sideInputCache.getIfPresent(id); - if (entry == null) { - return sideInputCache.get(id, fetchCallable).getValue(); - } else if (!entry.isReady()) { - // Invalidate the existing not-ready entry. This must be done atomically - // so that another thread doesn't replace the entry with a ready entry, which - // would then be deleted here. - synchronized (entry) { - SideInputCacheEntry newEntry = sideInputCache.getIfPresent(id); - if (newEntry != null && !newEntry.isReady()) { - sideInputCache.invalidate(id); - } - } - - return sideInputCache.get(id, fetchCallable).getValue(); - } else { - return entry.getValue(); - } - } else { - return sideInputCache.get(id, fetchCallable).getValue(); - } - } catch (Exception e) { - LOG.error("Fetch failed: ", e); - throw new RuntimeException("Exception while fetching side input: ", e); - } - } - - /** Struct representing a side input for a particular window. */ - static class SideInputId { - private final TupleTag tag; - private final BoundedWindow window; - - public SideInputId(TupleTag tag, BoundedWindow window) { - this.tag = tag; - this.window = window; - } - - @Override - public boolean equals(@Nullable Object other) { - if (other instanceof SideInputId) { - SideInputId otherId = (SideInputId) other; - return tag.equals(otherId.tag) && window.equals(otherId.window); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(tag, window); - } - } - - /** - * Entry in the side input cache that stores the value (null if not ready), and the encoded size - * of the value. - */ - static class SideInputCacheEntry { - private final boolean ready; - private final Object value; - private final int encodedSize; - - private SideInputCacheEntry(boolean ready, Object value, int encodedSize) { - this.ready = ready; - this.value = value; - this.encodedSize = encodedSize; - } - - public static SideInputCacheEntry ready(Object value, int encodedSize) { - return new SideInputCacheEntry(true, value, encodedSize); - } - - public static SideInputCacheEntry notReady() { - return new SideInputCacheEntry(false, null, 0); - } - - public boolean isReady() { - return ready; - } - - /** - * Returns {@literal null} if the side input was not ready, {@literal Optional.absent()} if the - * side input was null, and {@literal Optional.present(...)} if the side input was non-null. - */ - public @Nullable Optional getValue() { - @SuppressWarnings("unchecked") - T typed = (T) value; - return ready ? Optional.fromNullable(typed) : null; - } - - public int size() { - return encodedSize; - } - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 4c1693d61387..77f5205cf7e9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -94,6 +94,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.Work.State; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; @@ -228,7 +229,7 @@ public class StreamingDataflowWorker { private final Thread commitThread; private final AtomicLong activeCommitBytes = new AtomicLong(); private final AtomicBoolean running = new AtomicBoolean(); - private final StateFetcher stateFetcher; + private final SideInputStateFetcher sideInputStateFetcher; private final StreamingDataflowWorkerOptions options; private final boolean windmillServiceEnabled; private final long clientId; @@ -406,7 +407,7 @@ public void run() { this.metricTrackingWindmillServer = new MetricTrackingWindmillServerStub(windmillServer, memoryMonitor, windmillServiceEnabled); this.metricTrackingWindmillServer.start(); - this.stateFetcher = new StateFetcher(metricTrackingWindmillServer); + this.sideInputStateFetcher = new SideInputStateFetcher(metricTrackingWindmillServer); this.clientId = clientIdGenerator.nextLong(); for (MapTask mapTask : mapTasks) { @@ -1078,7 +1079,7 @@ public void close() { } }; }); - StateFetcher localStateFetcher = stateFetcher.byteTrackingView(); + SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView(); // If the read output KVs, then we can decode Windmill's byte key into a userland // key object and provide it to the execution context for use with per-key state. @@ -1114,7 +1115,7 @@ public void close() { outputDataWatermark, synchronizedProcessingTime, stateReader, - localStateFetcher, + localSideInputStateFetcher, outputBuilder); // Blocks while executing work. @@ -1184,7 +1185,7 @@ public void close() { shuffleBytesRead += message.getSerializedSize(); } } - long stateBytesRead = stateReader.getBytesRead() + localStateFetcher.getBytesRead(); + long stateBytesRead = stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead(); windmillShuffleBytesRead.addValue(shuffleBytesRead); windmillStateBytesRead.addValue(stateBytesRead); windmillStateBytesWritten.addValue(stateBytesWritten); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index c8fa6e6dfb78..d630601c28a3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.NavigableSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLong; @@ -45,6 +46,9 @@ import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; @@ -62,7 +66,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; @@ -86,7 +90,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext, Map> sideInputCache; + private final Map, Map>> sideInputCache; // Per-key cache of active Reader objects in use by this process. private final ImmutableMap stateNameMap; private final WindmillStateCache.ForComputation stateCache; @@ -104,7 +108,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext activeReader; private volatile long backlogBytes; @@ -145,20 +149,20 @@ public void start( @Nullable Instant outputDataWatermark, @Nullable Instant synchronizedProcessingTime, WindmillStateReader stateReader, - StateFetcher stateFetcher, + SideInputStateFetcher sideInputStateFetcher, Windmill.WorkItemCommitRequest.Builder outputBuilder) { this.key = key; this.work = work; this.computationKey = WindmillComputationKey.create(computationId, work.getKey(), work.getShardingKey()); - this.stateFetcher = stateFetcher; + this.sideInputStateFetcher = sideInputStateFetcher; this.outputBuilder = outputBuilder; this.sideInputCache.clear(); clearSinkFullHint(); Instant processingTime = Instant.now(); // Ensure that the processing time is greater than any fired processing time - // timers. Otherwise a trigger could ignore the timer and orphan the window. + // timers. Otherwise, a trigger could ignore the timer and orphan the window. for (Windmill.Timer timer : work.getTimers().getTimersList()) { if (timer.getType() == Windmill.Timer.Type.REALTIME) { Instant inferredFiringTime = @@ -208,42 +212,67 @@ protected SideInputReader getSideInputReaderForViews( return StreamingModeSideInputReader.of(views, this); } + @SuppressWarnings("deprecation") + private TupleTag getInternalTag(PCollectionView view) { + return view.getTagInternal(); + } + /** * Fetches the requested sideInput, and maintains a view of the cache that doesn't remove items * until the active work item is finished. * - *

If the side input was not ready, throws {@code IllegalStateException} if the state is - * {@literal CACHED_IN_WORKITEM} or returns null otherwise. - * - *

If the side input was ready and null, returns {@literal Optional.absent()}. If the side - * input was ready and non-null returns {@literal Optional.present(...)}. + *

If the side input was not cached, throws {@code IllegalStateException} if the state is + * {@literal CACHED_IN_WORK_ITEM} or returns {@link SideInput} which contains {@link + * Optional}. */ - private @Nullable Optional fetchSideInput( + private SideInput fetchSideInput( + PCollectionView view, + BoundedWindow sideInputWindow, + @Nullable String stateFamily, + SideInputState state, + @Nullable Supplier scopedReadStateSupplier) { + TupleTag viewInternalTag = getInternalTag(view); + Map> tagCache = + sideInputCache.computeIfAbsent(viewInternalTag, k -> new HashMap<>()); + + @SuppressWarnings("unchecked") + Optional> cachedSideInput = + Optional.ofNullable((SideInput) tagCache.get(sideInputWindow)); + + if (cachedSideInput.isPresent()) { + return cachedSideInput.get(); + } + + if (state == SideInputState.CACHED_IN_WORK_ITEM) { + throw new IllegalStateException( + "Expected side input to be cached. Tag: " + viewInternalTag.getId()); + } + + return fetchSideInputFromWindmill( + view, + sideInputWindow, + Preconditions.checkNotNull(stateFamily), + state, + Preconditions.checkNotNull(scopedReadStateSupplier), + tagCache); + } + + private SideInput fetchSideInputFromWindmill( PCollectionView view, BoundedWindow sideInputWindow, String stateFamily, - StateFetcher.SideInputState state, - Supplier scopedReadStateSupplier) { - Map tagCache = - sideInputCache.computeIfAbsent(view.getTagInternal(), k -> new HashMap<>()); + SideInputState state, + Supplier scopedReadStateSupplier, + Map> tagCache) { + SideInput fetched = + sideInputStateFetcher.fetchSideInput( + view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); - if (tagCache.containsKey(sideInputWindow)) { - @SuppressWarnings("unchecked") - T typed = (T) tagCache.get(sideInputWindow); - return Optional.fromNullable(typed); - } else { - if (state == StateFetcher.SideInputState.CACHED_IN_WORKITEM) { - throw new IllegalStateException( - "Expected side input to be cached. Tag: " + view.getTagInternal().getId()); - } - Optional fetched = - stateFetcher.fetchSideInput( - view, sideInputWindow, stateFamily, state, scopedReadStateSupplier); - if (fetched != null) { - tagCache.put(sideInputWindow, fetched.orNull()); - } - return fetched; + if (fetched.isReady()) { + tagCache.put(sideInputWindow, fetched); } + + return fetched; } public Iterable getSideInputNotifications() { @@ -378,8 +407,7 @@ String getStateFamily(NameContext nameContext) { interface StreamingModeStepContext { - boolean issueSideInputFetch( - PCollectionView view, BoundedWindow w, StateFetcher.SideInputState s); + boolean issueSideInputFetch(PCollectionView view, BoundedWindow w, SideInputState s); void addBlockingSideInput(Windmill.GlobalDataRequest blocked); @@ -412,10 +440,7 @@ public static class StreamingModeExecutionState // 2. The reporting thread calls extractUpdate which reads the current sum *AND* sets it to 0. private final AtomicLong totalMillisInState = new AtomicLong(); - // The worker that created this state. Used to report lulls back to the worker. - @SuppressWarnings("unused") // Affects a public api - private final StreamingDataflowWorker worker; - + @SuppressWarnings("unused") public StreamingModeExecutionState( NameContext nameContext, String stateName, @@ -424,7 +449,6 @@ public StreamingModeExecutionState( StreamingDataflowWorker worker) { // TODO: Take in the requesting step name and side input index for streaming. super(nameContext, stateName, null, null, metricsContainer, profileScope); - this.worker = worker; } /** @@ -513,8 +537,7 @@ public UserStepContext(StreamingModeExecutionContext.StepContext wrapped) { } @Override - public boolean issueSideInputFetch( - PCollectionView view, BoundedWindow w, StateFetcher.SideInputState s) { + public boolean issueSideInputFetch(PCollectionView view, BoundedWindow w, SideInputState s) { return wrapped.issueSideInputFetch(view, w, s); } @@ -609,9 +632,10 @@ public T get(PCollectionView view, BoundedWindow window) { view, window, null /* unused stateFamily */, - StateFetcher.SideInputState.CACHED_IN_WORKITEM, + SideInputState.CACHED_IN_WORK_ITEM, null /* unused readStateSupplier */) - .orNull(); + .value() + .orElse(null); } @Override @@ -883,10 +907,10 @@ public void writePCollectionViewData( /** Fetch the given side input asynchronously and return true if it is present. */ @Override public boolean issueSideInputFetch( - PCollectionView view, BoundedWindow mainInputWindow, StateFetcher.SideInputState state) { + PCollectionView view, BoundedWindow mainInputWindow, SideInputState state) { BoundedWindow sideInputWindow = view.getWindowMappingFn().getSideInputWindow(mainInputWindow); return fetchSideInput(view, sideInputWindow, stateFamily, state, scopedReadStateSupplier) - != null; + .isReady(); } /** Note that there is data on the current key that is blocked on the given side input. */ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java index 2b551acd2d8c..4f585e1c01b6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.core.TimerInternals.TimerData; import org.apache.beam.runners.core.TimerInternals.TimerDataCoder; import org.apache.beam.runners.core.TimerInternals.TimerDataCoderV2; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.sdk.coders.AtomicCoder; @@ -135,8 +136,7 @@ public Set getReadyWindows() { W window = entry.getKey(); boolean allSideInputsCached = true; for (PCollectionView view : sideInputViews.values()) { - if (!stepContext.issueSideInputFetch( - view, window, StateFetcher.SideInputState.KNOWN_READY)) { + if (!stepContext.issueSideInputFetch(view, window, SideInputState.KNOWN_READY)) { Windmill.GlobalDataRequest request = buildGlobalDataRequest(view, window); stepContext.addBlockingSideInput(request); windowBlockedSet.add(request); @@ -192,7 +192,7 @@ public boolean storeIfBlocked(WindowedValue elem) { Set blocked = blockedMap().get(window); if (blocked == null) { for (PCollectionView view : sideInputViews.values()) { - if (!stepContext.issueSideInputFetch(view, window, StateFetcher.SideInputState.UNKNOWN)) { + if (!stepContext.issueSideInputFetch(view, window, SideInputState.UNKNOWN)) { if (blocked == null) { blocked = new HashSet<>(); blockedMap().put(window, blocked); @@ -222,7 +222,7 @@ public boolean storeIfBlocked(TimerData timer) { boolean blocked = false; for (PCollectionView view : sideInputViews.values()) { - if (!stepContext.issueSideInputFetch(view, window, StateFetcher.SideInputState.UNKNOWN)) { + if (!stepContext.issueSideInputFetch(view, window, SideInputState.UNKNOWN)) { blocked = true; } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java new file mode 100644 index 000000000000..0054e42782a4 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import javax.annotation.Nullable; + +/** + * Entry in the side input cache that stores the value (null if not ready), and the encoded size of + * the value. + */ +@AutoValue +public abstract class SideInput { + static SideInput ready(@Nullable T value, int encodedSize) { + return new AutoValue_SideInput<>(true, Optional.ofNullable(value), encodedSize); + } + + static SideInput notReady() { + return new AutoValue_SideInput<>(false, Optional.empty(), 0); + } + + public abstract boolean isReady(); + + public abstract Optional value(); + + public abstract int size(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java new file mode 100644 index 000000000000..7dd589840667 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; + +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CheckReturnValue; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher; + +/** + * Wrapper around {@code Cache} that mostly delegates to the underlying + * cache, but adds threadsafe functionality to invalidate and load entries that are not ready. + * + * @implNote Returned values are explicitly cast, because the {@link #sideInputCache} holds wildcard + * types of all objects. + */ +@CheckReturnValue +@SuppressWarnings("unchecked") +final class SideInputCache { + + private static final long MAXIMUM_CACHE_WEIGHT = 100000000; /* 100 MB */ + private static final long CACHE_ENTRY_EXPIRY_MINUTES = 1L; + + private final Cache> sideInputCache; + + SideInputCache(Cache> sideInputCache) { + this.sideInputCache = sideInputCache; + } + + static SideInputCache create() { + return new SideInputCache( + CacheBuilder.newBuilder() + .maximumWeight(MAXIMUM_CACHE_WEIGHT) + .expireAfterWrite(CACHE_ENTRY_EXPIRY_MINUTES, TimeUnit.MINUTES) + .weigher((Weigher>) (id, entry) -> entry.size()) + .build()); + } + + synchronized SideInput invalidateThenLoadNewEntry( + Key key, Callable> cacheLoaderFn) throws ExecutionException { + // Invalidate the existing not-ready entry. This must be done atomically + // so that another thread doesn't replace the entry with a ready entry, which + // would then be deleted here. + SideInput newEntry = sideInputCache.getIfPresent(key); + if (newEntry != null && !newEntry.isReady()) { + sideInputCache.invalidate(key); + } + + return (SideInput) sideInputCache.get(key, cacheLoaderFn); + } + + Optional> get(Key key) { + return Optional.ofNullable((SideInput) sideInputCache.getIfPresent(key)); + } + + SideInput getOrLoad(Key key, Callable> cacheLoaderFn) + throws ExecutionException { + return (SideInput) sideInputCache.get(key, cacheLoaderFn); + } + + @AutoValue + abstract static class Key { + abstract TupleTag tag(); + + abstract BoundedWindow window(); + + static Key create(TupleTag tag, BoundedWindow window) { + return new AutoValue_SideInputCache_Key(tag, window); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java new file mode 100644 index 000000000000..d7af10d29e1f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputState.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; + +/** Indicates the caller's knowledge of whether a particular side input has been computed. */ +public enum SideInputState { + CACHED_IN_WORK_ITEM, + KNOWN_READY, + UNKNOWN +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java new file mode 100644 index 000000000000..b0862fc31d06 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; + +import static org.apache.beam.sdk.transforms.Materializations.ITERABLE_MATERIALIZATION_URN; +import static org.apache.beam.sdk.transforms.Materializations.MULTIMAP_MATERIALIZATION_URN; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Collections; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Callable; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.runners.core.InMemoryMultimapSideInputView; +import org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub; +import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Materializations.IterableView; +import org.apache.beam.sdk.transforms.Materializations.MultimapView; +import org.apache.beam.sdk.transforms.ViewFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Class responsible for fetching state from the windmill server. */ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +@NotThreadSafe +public class SideInputStateFetcher { + private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class); + + private static final Set SUPPORTED_MATERIALIZATIONS = + ImmutableSet.of(ITERABLE_MATERIALIZATION_URN, MULTIMAP_MATERIALIZATION_URN); + + private final SideInputCache sideInputCache; + private final MetricTrackingWindmillServerStub server; + private long bytesRead = 0L; + + public SideInputStateFetcher(MetricTrackingWindmillServerStub server) { + this(server, SideInputCache.create()); + } + + SideInputStateFetcher(MetricTrackingWindmillServerStub server, SideInputCache sideInputCache) { + this.server = server; + this.sideInputCache = sideInputCache; + } + + @SuppressWarnings("deprecation") + private static Iterable decodeRawData(Coder viewInternalCoder, GlobalData data) + throws IOException { + return !data.getData().isEmpty() + ? IterableCoder.of(viewInternalCoder).decode(data.getData().newInput(), Coder.Context.OUTER) + : Collections.emptyList(); + } + + /** Returns a view of the underlying cache that keeps track of bytes read separately. */ + public SideInputStateFetcher byteTrackingView() { + return new SideInputStateFetcher(server, sideInputCache); + } + + public long getBytesRead() { + return bytesRead; + } + + /** + * Fetch the given side input, storing it in a process-level cache. + * + *

If state is KNOWN_READY, attempt to fetch the data regardless of whether a not-ready entry + * was cached. + * + *

Returns {@literal null} if the side input was not ready, {@literal Optional.absent()} if the + * side input was null, and {@literal Optional.present(...)} if the side input was non-null. + */ + @SuppressWarnings("deprecation") + public SideInput fetchSideInput( + PCollectionView view, + BoundedWindow sideWindow, + String stateFamily, + SideInputState state, + Supplier scopedReadStateSupplier) { + Callable> loadSideInputFromWindmill = + () -> loadSideInputFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier); + + SideInputCache.Key sideInputCacheKey = + SideInputCache.Key.create(view.getTagInternal(), sideWindow); + + try { + if (state == SideInputState.KNOWN_READY) { + Optional> existingCacheEntry = sideInputCache.get(sideInputCacheKey); + if (!existingCacheEntry.isPresent()) { + return sideInputCache.getOrLoad(sideInputCacheKey, loadSideInputFromWindmill); + } + + if (!existingCacheEntry.get().isReady()) { + return sideInputCache.invalidateThenLoadNewEntry( + sideInputCacheKey, loadSideInputFromWindmill); + } + + return existingCacheEntry.get(); + } + + return sideInputCache.getOrLoad(sideInputCacheKey, loadSideInputFromWindmill); + } catch (Exception e) { + LOG.error("Fetch failed: ", e); + throw new RuntimeException("Exception while fetching side input: ", e); + } + } + + @SuppressWarnings({"deprecation", "unchecked"}) + private GlobalData fetchGlobalDataFromWindmill( + PCollectionView view, + SideWindowT sideWindow, + String stateFamily, + Supplier scopedReadStateSupplier) + throws IOException { + WindowingStrategy sideWindowStrategy = + (WindowingStrategy) view.getWindowingStrategyInternal(); + + Coder windowCoder = sideWindowStrategy.getWindowFn().windowCoder(); + + ByteStringOutputStream windowStream = new ByteStringOutputStream(); + windowCoder.encode(sideWindow, windowStream, Coder.Context.OUTER); + + Windmill.GlobalDataRequest request = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag(view.getTagInternal().getId()) + .setVersion(windowStream.toByteString()) + .build()) + .setStateFamily(stateFamily) + .setExistenceWatermarkDeadline( + WindmillTimeUtils.harnessToWindmillTimestamp( + sideWindowStrategy.getTrigger().getWatermarkThatGuaranteesFiring(sideWindow))) + .build(); + + try (Closeable ignored = scopedReadStateSupplier.get()) { + return server.getSideInputData(request); + } + } + + @SuppressWarnings("deprecation") + private SideInput loadSideInputFromWindmill( + PCollectionView view, + BoundedWindow sideWindow, + String stateFamily, + Supplier scopedReadStateSupplier) + throws IOException { + checkState( + SUPPORTED_MATERIALIZATIONS.contains(view.getViewFn().getMaterialization().getUrn()), + "Only materialization's of type %s supported, received %s", + SUPPORTED_MATERIALIZATIONS, + view.getViewFn().getMaterialization().getUrn()); + + GlobalData data = + fetchGlobalDataFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier); + bytesRead += data.getSerializedSize(); + return data.getIsReady() ? createSideInputCacheEntry(view, data) : SideInput.notReady(); + } + + @SuppressWarnings({"deprecation", "unchecked"}) + private SideInput createSideInputCacheEntry(PCollectionView view, GlobalData data) + throws IOException { + Iterable rawData = decodeRawData(view.getCoderInternal(), data); + switch (view.getViewFn().getMaterialization().getUrn()) { + case ITERABLE_MATERIALIZATION_URN: + { + ViewFn viewFn = (ViewFn) view.getViewFn(); + return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size()); + } + case MULTIMAP_MATERIALIZATION_URN: + { + ViewFn viewFn = (ViewFn) view.getViewFn(); + Coder keyCoder = ((KvCoder) view.getCoderInternal()).getKeyCoder(); + return SideInput.ready( + viewFn.apply( + InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)), + data.getData().size()); + } + default: + throw new IllegalStateException( + String.format( + "Unknown side input materialization format requested '%s'", + view.getViewFn().getMaterialization().getUrn())); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 24e6e2795c68..dcd0edf50b78 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -279,11 +279,7 @@ private static CounterUpdate getCounter(Iterable counters, String } static Work createMockWork(long workToken) { - return Work.create( - Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), - Instant::now, - Collections.emptyList(), - work -> {}); + return createMockWork(workToken, work -> {}); } static Work createMockWork(long workToken, Consumer processWorkFn) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 6620dbdaab79..9991520d593b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -55,6 +55,7 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; @@ -83,10 +84,10 @@ @RunWith(JUnit4.class) public class StreamingModeExecutionContextTest { - @Mock private StateFetcher stateFetcher; + @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; - private StreamingModeExecutionStateRegistry executionStateRegistry = + private final StreamingModeExecutionStateRegistry executionStateRegistry = new StreamingModeExecutionStateRegistry(null); private StreamingModeExecutionContext executionContext; DataflowWorkerHarnessOptions options; @@ -133,7 +134,7 @@ public void testTimerInternalsSetTimer() { null, // output watermark null, // synchronized processing time stateReader, - stateFetcher, + sideInputStateFetcher, outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); @@ -183,7 +184,7 @@ public void testTimerInternalsProcessingTimeSkew() { null, // output watermark null, // synchronized processing time stateReader, - stateFetcher, + sideInputStateFetcher, outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java index 05e0ff417615..3c121ab27f76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java @@ -39,7 +39,7 @@ import org.apache.beam.runners.core.SideInputReader; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespaces; -import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.util.ListOutputManager; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java index 9ce462be3211..a7196613fbb1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcherTest.java @@ -31,7 +31,7 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.TimerInternals.TimerData; -import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.state.BagState; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java similarity index 70% rename from runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java rename to runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java index 13d8a9bd3ffb..d9eee8bf519d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StateFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java @@ -15,11 +15,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.runners.dataflow.worker; +package org.apache.beam.runners.dataflow.worker.streaming.sideinput; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.mockito.Matchers.any; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -27,10 +29,10 @@ import static org.mockito.Mockito.when; import java.io.Closeable; -import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; -import org.apache.beam.runners.dataflow.worker.StateFetcher.SideInputState; +import org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.ListCoder; @@ -56,9 +58,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -/** Unit tests for {@link StateFetcher}. */ +/** Unit tests for {@link SideInputStateFetcher}. */ +@SuppressWarnings("deprecation") @RunWith(JUnit4.class) -public class StateFetcherTest { +public class SideInputStateFetcherTest { private static final String STATE_FAMILY = "state"; @Mock MetricTrackingWindmillServerStub server; @@ -72,10 +75,11 @@ public void setUp() { @Test public void testFetchGlobalDataBasic() throws Exception { - StateFetcher fetcher = new StateFetcher(server); + SideInputStateFetcher fetcher = new SideInputStateFetcher(server); ByteStringOutputStream stream = new ByteStringOutputStream(); - ListCoder.of(StringUtf8Coder.of()).encode(Arrays.asList("data"), stream, Coder.Context.OUTER); + ListCoder.of(StringUtf8Coder.of()) + .encode(Collections.singletonList("data"), stream, Coder.Context.OUTER); ByteString encodedIterable = stream.toByteString(); PCollectionView view = @@ -87,17 +91,29 @@ public void testFetchGlobalDataBasic() throws Exception { // then the data is already cached. when(server.getSideInputData(any(Windmill.GlobalDataRequest.class))) .thenReturn( - buildGlobalDataResponse(tag, ByteString.EMPTY, false, null), - buildGlobalDataResponse(tag, ByteString.EMPTY, true, encodedIterable)); + buildGlobalDataResponse(tag, false, null), + buildGlobalDataResponse(tag, true, encodedIterable)); + + assertFalse( + fetcher + .fetchSideInput( + view, + GlobalWindow.INSTANCE, + STATE_FAMILY, + SideInputState.UNKNOWN, + readStateSupplier) + .isReady()); + + assertFalse( + fetcher + .fetchSideInput( + view, + GlobalWindow.INSTANCE, + STATE_FAMILY, + SideInputState.UNKNOWN, + readStateSupplier) + .isReady()); - assertEquals( - null, - fetcher.fetchSideInput( - view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier)); - assertEquals( - null, - fetcher.fetchSideInput( - view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier)); assertEquals( "data", fetcher @@ -107,7 +123,8 @@ public void testFetchGlobalDataBasic() throws Exception { STATE_FAMILY, SideInputState.KNOWN_READY, readStateSupplier) - .orNull()); + .value() + .orElse(null)); assertEquals( "data", fetcher @@ -117,18 +134,20 @@ public void testFetchGlobalDataBasic() throws Exception { STATE_FAMILY, SideInputState.KNOWN_READY, readStateSupplier) - .orNull()); + .value() + .orElse(null)); - verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag, ByteString.EMPTY)); + verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag)); verifyNoMoreInteractions(server); } @Test public void testFetchGlobalDataNull() throws Exception { - StateFetcher fetcher = new StateFetcher(server); + SideInputStateFetcher fetcher = new SideInputStateFetcher(server); ByteStringOutputStream stream = new ByteStringOutputStream(); - ListCoder.of(VoidCoder.of()).encode(Arrays.asList((Void) null), stream, Coder.Context.OUTER); + ListCoder.of(VoidCoder.of()) + .encode(Collections.singletonList(null), stream, Coder.Context.OUTER); ByteString encodedIterable = stream.toByteString(); PCollectionView view = @@ -140,19 +159,28 @@ public void testFetchGlobalDataNull() throws Exception { // then the data is already cached. when(server.getSideInputData(any(Windmill.GlobalDataRequest.class))) .thenReturn( - buildGlobalDataResponse(tag, ByteString.EMPTY, false, null), - buildGlobalDataResponse(tag, ByteString.EMPTY, true, encodedIterable)); + buildGlobalDataResponse(tag, false, null), + buildGlobalDataResponse(tag, true, encodedIterable)); - assertEquals( - null, - fetcher.fetchSideInput( - view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier)); - assertEquals( - null, - fetcher.fetchSideInput( - view, GlobalWindow.INSTANCE, STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier)); - assertEquals( - null, + assertFalse( + fetcher + .fetchSideInput( + view, + GlobalWindow.INSTANCE, + STATE_FAMILY, + SideInputState.UNKNOWN, + readStateSupplier) + .isReady()); + assertFalse( + fetcher + .fetchSideInput( + view, + GlobalWindow.INSTANCE, + STATE_FAMILY, + SideInputState.UNKNOWN, + readStateSupplier) + .isReady()); + assertNull( fetcher .fetchSideInput( view, @@ -160,9 +188,9 @@ public void testFetchGlobalDataNull() throws Exception { STATE_FAMILY, SideInputState.KNOWN_READY, readStateSupplier) - .orNull()); - assertEquals( - null, + .value() + .orElse(null)); + assertNull( fetcher .fetchSideInput( view, @@ -170,9 +198,10 @@ public void testFetchGlobalDataNull() throws Exception { STATE_FAMILY, SideInputState.KNOWN_READY, readStateSupplier) - .orNull()); + .value() + .orElse(null)); - verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag, ByteString.EMPTY)); + verify(server, times(2)).getSideInputData(buildGlobalDataRequest(tag)); verifyNoMoreInteractions(server); } @@ -181,15 +210,14 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { Coder> coder = ListCoder.of(StringUtf8Coder.of()); ByteStringOutputStream stream = new ByteStringOutputStream(); - coder.encode(Arrays.asList("data1"), stream, Coder.Context.OUTER); + coder.encode(Collections.singletonList("data1"), stream, Coder.Context.OUTER); ByteString encodedIterable1 = stream.toByteStringAndReset(); - coder.encode(Arrays.asList("data2"), stream, Coder.Context.OUTER); + coder.encode(Collections.singletonList("data2"), stream, Coder.Context.OUTER); ByteString encodedIterable2 = stream.toByteString(); - Cache cache = - CacheBuilder.newBuilder().build(); + Cache> cache = CacheBuilder.newBuilder().build(); - StateFetcher fetcher = new StateFetcher(server, cache); + SideInputStateFetcher fetcher = new SideInputStateFetcher(server, new SideInputCache(cache)); PCollectionView view1 = TestPipeline.create().apply(Create.empty(StringUtf8Coder.of())).apply(View.asSingleton()); @@ -204,9 +232,9 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { // then view 1 again twice. when(server.getSideInputData(any(Windmill.GlobalDataRequest.class))) .thenReturn( - buildGlobalDataResponse(tag1, ByteString.EMPTY, true, encodedIterable1), - buildGlobalDataResponse(tag2, ByteString.EMPTY, true, encodedIterable2), - buildGlobalDataResponse(tag1, ByteString.EMPTY, true, encodedIterable1)); + buildGlobalDataResponse(tag1, true, encodedIterable1), + buildGlobalDataResponse(tag2, true, encodedIterable2), + buildGlobalDataResponse(tag1, true, encodedIterable1)); assertEquals( "data1", @@ -217,7 +245,8 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier) - .orNull()); + .value() + .orElse(null)); assertEquals( "data2", fetcher @@ -227,7 +256,8 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier) - .orNull()); + .value() + .orElse(null)); cache.invalidateAll(); assertEquals( "data1", @@ -238,7 +268,8 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier) - .orNull()); + .value() + .orElse(null)); assertEquals( "data1", fetcher @@ -248,7 +279,8 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier) - .orNull()); + .value() + .orElse(null)); ArgumentCaptor captor = ArgumentCaptor.forClass(Windmill.GlobalDataRequest.class); @@ -259,14 +291,14 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { assertThat( captor.getAllValues(), contains( - buildGlobalDataRequest(tag1, ByteString.EMPTY), - buildGlobalDataRequest(tag2, ByteString.EMPTY), - buildGlobalDataRequest(tag1, ByteString.EMPTY))); + buildGlobalDataRequest(tag1), + buildGlobalDataRequest(tag2), + buildGlobalDataRequest(tag1))); } @Test public void testEmptyFetchGlobalData() throws Exception { - StateFetcher fetcher = new StateFetcher(server); + SideInputStateFetcher fetcher = new SideInputStateFetcher(server); ByteString encodedIterable = ByteString.EMPTY; @@ -280,7 +312,7 @@ public void testEmptyFetchGlobalData() throws Exception { // Test three calls in a row. First, data is not ready, then data is ready, // then the data is already cached. when(server.getSideInputData(any(Windmill.GlobalDataRequest.class))) - .thenReturn(buildGlobalDataResponse(tag, ByteString.EMPTY, true, encodedIterable)); + .thenReturn(buildGlobalDataResponse(tag, true, encodedIterable)); assertEquals( 0L, @@ -292,17 +324,22 @@ public void testEmptyFetchGlobalData() throws Exception { STATE_FAMILY, SideInputState.UNKNOWN, readStateSupplier) - .orNull()); + .value() + .orElse(null)); - verify(server).getSideInputData(buildGlobalDataRequest(tag, ByteString.EMPTY)); + verify(server).getSideInputData(buildGlobalDataRequest(tag)); verifyNoMoreInteractions(server); } private Windmill.GlobalData buildGlobalDataResponse( - String tag, ByteString version, boolean isReady, ByteString data) { + String tag, boolean isReady, ByteString data) { Windmill.GlobalData.Builder builder = Windmill.GlobalData.newBuilder() - .setDataId(Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build()); + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag(tag) + .setVersion(ByteString.EMPTY) + .build()); if (isReady) { builder.setIsReady(true).setData(data); @@ -312,9 +349,9 @@ private Windmill.GlobalData buildGlobalDataResponse( return builder.build(); } - private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag, ByteString version) { + private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { Windmill.GlobalDataId id = - Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build(); + Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(ByteString.EMPTY).build(); return Windmill.GlobalDataRequest.newBuilder() .setDataId(id) From 7732ecf99b187304505e132955e358eaeb8c0524 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Thu, 5 Oct 2023 16:43:17 -0700 Subject: [PATCH 2/2] make @SuppressWarnings annotations local and with comments --- .../worker/streaming/sideinput/SideInput.java | 11 +- .../streaming/sideinput/SideInputCache.java | 50 ++++++--- .../sideinput/SideInputStateFetcher.java | 104 +++++++++++------- .../sideinput/SideInputStateFetcherTest.java | 17 ++- 4 files changed, 122 insertions(+), 60 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java index 0054e42782a4..04eecadc1e5c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInput.java @@ -22,8 +22,15 @@ import javax.annotation.Nullable; /** - * Entry in the side input cache that stores the value (null if not ready), and the encoded size of - * the value. + * Entry in the side input cache that stores the value and the encoded size of the value. + * + *

Can be in 1 of 3 states: + * + *

    + *
  • Ready with a value. + *
  • Ready with no value, represented as {@link Optional} + *
  • Not ready. + *
*/ @AutoValue public abstract class SideInput { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java index 7dd589840667..721c477435ef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputCache.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher; @@ -37,15 +38,14 @@ * types of all objects. */ @CheckReturnValue -@SuppressWarnings("unchecked") final class SideInputCache { private static final long MAXIMUM_CACHE_WEIGHT = 100000000; /* 100 MB */ private static final long CACHE_ENTRY_EXPIRY_MINUTES = 1L; - private final Cache> sideInputCache; + private final Cache, SideInput> sideInputCache; - SideInputCache(Cache> sideInputCache) { + SideInputCache(Cache, SideInput> sideInputCache) { this.sideInputCache = sideInputCache; } @@ -54,40 +54,60 @@ static SideInputCache create() { CacheBuilder.newBuilder() .maximumWeight(MAXIMUM_CACHE_WEIGHT) .expireAfterWrite(CACHE_ENTRY_EXPIRY_MINUTES, TimeUnit.MINUTES) - .weigher((Weigher>) (id, entry) -> entry.size()) + .weigher((Weigher, SideInput>) (id, entry) -> entry.size()) .build()); } synchronized SideInput invalidateThenLoadNewEntry( - Key key, Callable> cacheLoaderFn) throws ExecutionException { + Key key, Callable> cacheLoaderFn) throws ExecutionException { // Invalidate the existing not-ready entry. This must be done atomically // so that another thread doesn't replace the entry with a ready entry, which // would then be deleted here. - SideInput newEntry = sideInputCache.getIfPresent(key); - if (newEntry != null && !newEntry.isReady()) { + Optional> newEntry = getIfPresentUnchecked(key); + if (newEntry.isPresent() && !newEntry.get().isReady()) { sideInputCache.invalidate(key); } - return (SideInput) sideInputCache.get(key, cacheLoaderFn); + return getUnchecked(key, cacheLoaderFn); } - Optional> get(Key key) { - return Optional.ofNullable((SideInput) sideInputCache.getIfPresent(key)); + Optional> get(Key key) { + return getIfPresentUnchecked(key); + } + + SideInput getOrLoad(Key key, Callable> cacheLoaderFn) + throws ExecutionException { + return getUnchecked(key, cacheLoaderFn); } - SideInput getOrLoad(Key key, Callable> cacheLoaderFn) + @SuppressWarnings({ + "unchecked" // cacheLoaderFn loads SideInput, and key is of type T, so value for Key is + // always SideInput. + }) + private SideInput getUnchecked(Key key, Callable> cacheLoaderFn) throws ExecutionException { return (SideInput) sideInputCache.get(key, cacheLoaderFn); } + @SuppressWarnings({ + "unchecked" // cacheLoaderFn loads SideInput, and key is of type T, so value for Key is + // always SideInput. + }) + private Optional> getIfPresentUnchecked(Key key) { + return Optional.ofNullable((SideInput) sideInputCache.getIfPresent(key)); + } + @AutoValue - abstract static class Key { + abstract static class Key { + static Key create( + TupleTag tag, BoundedWindow window, TypeDescriptor typeDescriptor) { + return new AutoValue_SideInputCache_Key<>(tag, window, typeDescriptor); + } + abstract TupleTag tag(); abstract BoundedWindow window(); - static Key create(TupleTag tag, BoundedWindow window) { - return new AutoValue_SideInputCache_Key(tag, window); - } + abstract TypeDescriptor typeDescriptor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java index b0862fc31d06..aa61c4219353 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -49,10 +50,6 @@ import org.slf4j.LoggerFactory; /** Class responsible for fetching state from the windmill server. */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) @NotThreadSafe public class SideInputStateFetcher { private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class); @@ -73,14 +70,32 @@ public SideInputStateFetcher(MetricTrackingWindmillServerStub server) { this.sideInputCache = sideInputCache; } - @SuppressWarnings("deprecation") - private static Iterable decodeRawData(Coder viewInternalCoder, GlobalData data) + private static Iterable decodeRawData(PCollectionView view, GlobalData data) throws IOException { return !data.getData().isEmpty() - ? IterableCoder.of(viewInternalCoder).decode(data.getData().newInput(), Coder.Context.OUTER) + ? IterableCoder.of(getCoder(view)).decode(data.getData().newInput()) : Collections.emptyList(); } + @SuppressWarnings({ + "deprecation" // Required as part of the SideInputCacheKey, and not exposed. + }) + private static TupleTag getInternalTag(PCollectionView view) { + return view.getTagInternal(); + } + + @SuppressWarnings("deprecation") + private static ViewFn getViewFn(PCollectionView view) { + return view.getViewFn(); + } + + @SuppressWarnings({ + "deprecation" // The view's internal coder is required to decode the raw data. + }) + private static Coder getCoder(PCollectionView view) { + return view.getCoderInternal(); + } + /** Returns a view of the underlying cache that keeps track of bytes read separately. */ public SideInputStateFetcher byteTrackingView() { return new SideInputStateFetcher(server, sideInputCache); @@ -95,11 +110,7 @@ public long getBytesRead() { * *

If state is KNOWN_READY, attempt to fetch the data regardless of whether a not-ready entry * was cached. - * - *

Returns {@literal null} if the side input was not ready, {@literal Optional.absent()} if the - * side input was null, and {@literal Optional.present(...)} if the side input was non-null. */ - @SuppressWarnings("deprecation") public SideInput fetchSideInput( PCollectionView view, BoundedWindow sideWindow, @@ -108,9 +119,9 @@ public SideInput fetchSideInput( Supplier scopedReadStateSupplier) { Callable> loadSideInputFromWindmill = () -> loadSideInputFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier); - - SideInputCache.Key sideInputCacheKey = - SideInputCache.Key.create(view.getTagInternal(), sideWindow); + SideInputCache.Key sideInputCacheKey = + SideInputCache.Key.create( + getInternalTag(view), sideWindow, getViewFn(view).getTypeDescriptor()); try { if (state == SideInputState.KNOWN_READY) { @@ -134,26 +145,29 @@ public SideInput fetchSideInput( } } - @SuppressWarnings({"deprecation", "unchecked"}) private GlobalData fetchGlobalDataFromWindmill( PCollectionView view, SideWindowT sideWindow, String stateFamily, Supplier scopedReadStateSupplier) throws IOException { + @SuppressWarnings({ + "deprecation", // Internal windowStrategy is required to fetch side input data from Windmill. + "unchecked" // Internal windowing strategy matches WindowingStrategy. + }) WindowingStrategy sideWindowStrategy = (WindowingStrategy) view.getWindowingStrategyInternal(); Coder windowCoder = sideWindowStrategy.getWindowFn().windowCoder(); ByteStringOutputStream windowStream = new ByteStringOutputStream(); - windowCoder.encode(sideWindow, windowStream, Coder.Context.OUTER); + windowCoder.encode(sideWindow, windowStream); Windmill.GlobalDataRequest request = Windmill.GlobalDataRequest.newBuilder() .setDataId( Windmill.GlobalDataId.newBuilder() - .setTag(view.getTagInternal().getId()) + .setTag(getInternalTag(view).getId()) .setVersion(windowStream.toByteString()) .build()) .setStateFamily(stateFamily) @@ -167,49 +181,65 @@ private GlobalData fetchGlobalDataFromWin } } - @SuppressWarnings("deprecation") private SideInput loadSideInputFromWindmill( PCollectionView view, BoundedWindow sideWindow, String stateFamily, Supplier scopedReadStateSupplier) throws IOException { - checkState( - SUPPORTED_MATERIALIZATIONS.contains(view.getViewFn().getMaterialization().getUrn()), - "Only materialization's of type %s supported, received %s", - SUPPORTED_MATERIALIZATIONS, - view.getViewFn().getMaterialization().getUrn()); - + validateViewMaterialization(view); GlobalData data = fetchGlobalDataFromWindmill(view, sideWindow, stateFamily, scopedReadStateSupplier); bytesRead += data.getSerializedSize(); return data.getIsReady() ? createSideInputCacheEntry(view, data) : SideInput.notReady(); } - @SuppressWarnings({"deprecation", "unchecked"}) + private void validateViewMaterialization(PCollectionView view) { + String materializationUrn = getViewFn(view).getMaterialization().getUrn(); + checkState( + SUPPORTED_MATERIALIZATIONS.contains(materializationUrn), + "Only materialization's of type %s supported, received %s", + SUPPORTED_MATERIALIZATIONS, + materializationUrn); + } + private SideInput createSideInputCacheEntry(PCollectionView view, GlobalData data) throws IOException { - Iterable rawData = decodeRawData(view.getCoderInternal(), data); - switch (view.getViewFn().getMaterialization().getUrn()) { + Iterable rawData = decodeRawData(view, data); + switch (getViewFn(view).getMaterialization().getUrn()) { case ITERABLE_MATERIALIZATION_URN: { - ViewFn viewFn = (ViewFn) view.getViewFn(); + @SuppressWarnings({ + "unchecked", // ITERABLE_MATERIALIZATION_URN has ViewFn. + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) + }) + ViewFn viewFn = (ViewFn) getViewFn(view); return SideInput.ready(viewFn.apply(() -> rawData), data.getData().size()); } case MULTIMAP_MATERIALIZATION_URN: { - ViewFn viewFn = (ViewFn) view.getViewFn(); - Coder keyCoder = ((KvCoder) view.getCoderInternal()).getKeyCoder(); - return SideInput.ready( + @SuppressWarnings({ + "unchecked", // MULTIMAP_MATERIALIZATION_URN has ViewFn. + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) + }) + ViewFn viewFn = (ViewFn) getViewFn(view); + Coder keyCoder = ((KvCoder) getCoder(view)).getKeyCoder(); + + @SuppressWarnings({ + "unchecked", // Safe since multimap rawData is of type Iterable> + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) + }) + T multimapSideInputValue = viewFn.apply( - InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)), - data.getData().size()); + InMemoryMultimapSideInputView.fromIterable(keyCoder, (Iterable) rawData)); + return SideInput.ready(multimapSideInputValue, data.getData().size()); } default: - throw new IllegalStateException( - String.format( - "Unknown side input materialization format requested '%s'", - view.getViewFn().getMaterialization().getUrn())); + { + throw new IllegalStateException( + "Unknown side input materialization format requested: " + + getViewFn(view).getMaterialization().getUrn()); + } } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java index d9eee8bf519d..daf814618791 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcherTest.java @@ -59,14 +59,15 @@ import org.mockito.MockitoAnnotations; /** Unit tests for {@link SideInputStateFetcher}. */ +// TODO: Add tests with different encoded windows to verify version is correctly plumbed. @SuppressWarnings("deprecation") @RunWith(JUnit4.class) public class SideInputStateFetcherTest { private static final String STATE_FAMILY = "state"; - @Mock MetricTrackingWindmillServerStub server; + @Mock private MetricTrackingWindmillServerStub server; - @Mock Supplier readStateSupplier; + @Mock private Supplier readStateSupplier; @Before public void setUp() { @@ -215,7 +216,7 @@ public void testFetchGlobalDataCacheOverflow() throws Exception { coder.encode(Collections.singletonList("data2"), stream, Coder.Context.OUTER); ByteString encodedIterable2 = stream.toByteString(); - Cache> cache = CacheBuilder.newBuilder().build(); + Cache, SideInput> cache = CacheBuilder.newBuilder().build(); SideInputStateFetcher fetcher = new SideInputStateFetcher(server, new SideInputCache(cache)); @@ -331,7 +332,7 @@ public void testEmptyFetchGlobalData() throws Exception { verifyNoMoreInteractions(server); } - private Windmill.GlobalData buildGlobalDataResponse( + private static Windmill.GlobalData buildGlobalDataResponse( String tag, boolean isReady, ByteString data) { Windmill.GlobalData.Builder builder = Windmill.GlobalData.newBuilder() @@ -349,9 +350,9 @@ private Windmill.GlobalData buildGlobalDataResponse( return builder.build(); } - private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { + private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag, ByteString version) { Windmill.GlobalDataId id = - Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(ByteString.EMPTY).build(); + Windmill.GlobalDataId.newBuilder().setTag(tag).setVersion(version).build(); return Windmill.GlobalDataRequest.newBuilder() .setDataId(id) @@ -360,4 +361,8 @@ private Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { TimeUnit.MILLISECONDS.toMicros(GlobalWindow.INSTANCE.maxTimestamp().getMillis())) .build(); } + + private static Windmill.GlobalDataRequest buildGlobalDataRequest(String tag) { + return buildGlobalDataRequest(tag, ByteString.EMPTY); + } }