diff --git a/CHANGES.md b/CHANGES.md index 38fa6e44b73d..0a620038f11e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -68,6 +68,7 @@ * Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)). * Removed a 3rd party LGPL dependency from the Go SDK ([#31765](https://github.com/apache/beam/issues/31765)). +* Support for MapState and SetState when using Dataflow Runner v1 with Streaming Engine (Java) ([[#18200](https://github.com/apache/beam/issues/18200)]) ## Breaking Changes diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index 7ffb10c85c0a..6ed7f8525fdc 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -257,6 +257,14 @@ public static StateTag> convertToMapTagInternal( new StructuredId(setTag.getId()), StateSpecs.convertToMapSpecInternal(setTag.getSpec())); } + public static StateTag> convertToMultiMapTagInternal( + StateTag> mapTag) { + StateSpec> spec = mapTag.getSpec(); + StateSpec> multimapSpec = + StateSpecs.convertToMultimapSpecInternal(spec); + return new SimpleStateTag<>(new StructuredId(mapTag.getId()), multimapSpec); + } + private static class StructuredId implements Serializable { private final StateKind kind; private final String rawId; diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index de566599bf88..708c63413268 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2564,11 +2564,6 @@ static boolean useUnifiedWorker(DataflowPipelineOptions options) { || hasExperiment(options, "use_portable_job_submission"); } - static boolean useStreamingEngine(DataflowPipelineOptions options) { - return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT) - || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT); - } - static void verifyDoFnSupported( DoFn fn, boolean streaming, DataflowPipelineOptions options) { if (!streaming && DoFnSignatures.usesMultimapState(fn)) { @@ -2583,8 +2578,6 @@ static void verifyDoFnSupported( "%s does not currently support @RequiresTimeSortedInput in streaming mode.", DataflowRunner.class.getSimpleName())); } - - boolean streamingEngine = useStreamingEngine(options); boolean isUnifiedWorker = useUnifiedWorker(options); if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) { @@ -2593,25 +2586,17 @@ static void verifyDoFnSupported( "%s does not currently support %s running using streaming on unified worker", DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName())); } - if (DoFnSignatures.usesSetState(fn)) { - if (streaming && (isUnifiedWorker || streamingEngine)) { - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s when using %s", - DataflowRunner.class.getSimpleName(), - SetState.class.getSimpleName(), - isUnifiedWorker ? "streaming on unified worker" : "streaming engine")); - } + if (DoFnSignatures.usesSetState(fn) && streaming && isUnifiedWorker) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming on unified worker", + DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); } - if (DoFnSignatures.usesMapState(fn)) { - if (streaming && (isUnifiedWorker || streamingEngine)) { - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s when using %s", - DataflowRunner.class.getSimpleName(), - MapState.class.getSimpleName(), - isUnifiedWorker ? "streaming on unified worker" : "streaming engine")); - } + if (DoFnSignatures.usesMapState(fn) && streaming && isUnifiedWorker) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming on unified worker", + DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); } if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) { throw new UnsupportedOperationException( diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 55bfc44ee62b..cf1066e41d25 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -131,8 +131,6 @@ import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; -import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.state.ValueState; @@ -1880,63 +1878,6 @@ public void testSettingConflictingEnableAndDisableExperimentsThrowsException() t } } - private void verifyMapStateUnsupported(PipelineOptions options) throws Exception { - Pipeline p = Pipeline.create(options); - p.apply(Create.of(KV.of(13, 42))) - .apply( - ParDo.of( - new DoFn, Void>() { - - @StateId("fizzle") - private final StateSpec> voidState = StateSpecs.map(); - - @ProcessElement - public void process() {} - })); - - thrown.expectMessage("MapState"); - thrown.expect(UnsupportedOperationException.class); - p.run(); - } - - @Test - public void testMapStateUnsupportedStreamingEngine() throws Exception { - PipelineOptions options = buildPipelineOptions(); - ExperimentalOptions.addExperiment( - options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); - options.as(DataflowPipelineOptions.class).setStreaming(true); - - verifyMapStateUnsupported(options); - } - - private void verifySetStateUnsupported(PipelineOptions options) throws Exception { - Pipeline p = Pipeline.create(options); - p.apply(Create.of(KV.of(13, 42))) - .apply( - ParDo.of( - new DoFn, Void>() { - - @StateId("fizzle") - private final StateSpec> voidState = StateSpecs.set(); - - @ProcessElement - public void process() {} - })); - - thrown.expectMessage("SetState"); - thrown.expect(UnsupportedOperationException.class); - p.run(); - } - - @Test - public void testSetStateUnsupportedStreamingEngine() throws Exception { - PipelineOptions options = buildPipelineOptions(); - ExperimentalOptions.addExperiment( - options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); - options.as(DataflowPipelineOptions.class).setStreaming(true); - verifySetStateUnsupported(options); - } - /** Records all the composite transforms visited within the Pipeline. */ private static class CompositeTransformRecorder extends PipelineVisitor.Defaults { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 59819db88a07..0e46e7e4687e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -324,7 +324,10 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(Integer.MAX_VALUE); WindmillStateCache windmillStateCache = - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(options.isEnableStreamingEngine()) + .build(); Function executorSupplier = threadName -> Executors.newSingleThreadScheduledExecutor( @@ -478,7 +481,11 @@ static StreamingDataflowWorker forTesting( ConcurrentMap stageInfo = new ConcurrentHashMap<>(); AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(maxWorkItemCommitBytesOverrides); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); - WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + WindmillStateCache stateCache = + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(options.isEnableStreamingEngine()) + .build(); ComputationConfig.Fetcher configFetcher = options.isEnableStreamingEngine() ? StreamingEngineComputationConfigFetcher.forTesting( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java new file mode 100644 index 000000000000..e144d5cf8c3f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.state; + +import org.apache.beam.sdk.state.MapState; + +public abstract class AbstractWindmillMap extends SimpleWindmillState + implements MapState {} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java index bcaf8bf21a2d..c026aac4f96b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java @@ -24,17 +24,9 @@ import org.apache.beam.runners.core.StateTable; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.MultimapState; -import org.apache.beam.sdk.state.OrderedListState; -import org.apache.beam.sdk.state.SetState; -import org.apache.beam.sdk.state.State; -import org.apache.beam.sdk.state.StateContext; -import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.state.*; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; @@ -50,6 +42,7 @@ final class CachingStateTable extends StateTable { private final Supplier scopedReadStateSupplier; private final @Nullable StateTable derivedStateTable; private final boolean isNewKey; + private final boolean mapStateViaMultimapState; private CachingStateTable(Builder builder) { this.stateFamily = builder.stateFamily; @@ -59,6 +52,7 @@ private CachingStateTable(Builder builder) { this.isNewKey = builder.isNewKey; this.scopedReadStateSupplier = builder.scopedReadStateSupplier; this.derivedStateTable = builder.derivedStateTable; + this.mapStateViaMultimapState = builder.mapStateViaMultimapState; if (this.isSystemTable) { Preconditions.checkState(derivedStateTable == null); @@ -103,30 +97,39 @@ public BagState bindBag(StateTag> address, Coder elemCoder @Override public SetState bindSet(StateTag> spec, Coder elemCoder) { + StateTag> internalMapAddress = StateTags.convertToMapTagInternal(spec); WindmillSet result = - new WindmillSet<>(namespace, spec, stateFamily, elemCoder, cache, isNewKey); + new WindmillSet<>(bindMap(internalMapAddress, elemCoder, BooleanCoder.of())); result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; } @Override - public MapState bindMap( + public AbstractWindmillMap bindMap( StateTag> spec, Coder keyCoder, Coder valueCoder) { - WindmillMap result = - cache - .get(namespace, spec) - .map(mapState -> (WindmillMap) mapState) - .orElseGet( - () -> - new WindmillMap<>( - namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey)); - + AbstractWindmillMap result; + if (mapStateViaMultimapState) { + StateTag> internalMultimapAddress = + StateTags.convertToMultiMapTagInternal(spec); + result = + new WindmillMapViaMultimap<>( + bindMultimap(internalMultimapAddress, keyCoder, valueCoder)); + } else { + result = + cache + .get(namespace, spec) + .map(mapState -> (AbstractWindmillMap) mapState) + .orElseGet( + () -> + new WindmillMap<>( + namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey)); + } result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; } @Override - public MultimapState bindMultimap( + public WindmillMultimap bindMultimap( StateTag> spec, Coder keyCoder, Coder valueCoder) { @@ -246,6 +249,7 @@ static class Builder { private final boolean isNewKey; private boolean isSystemTable; private @Nullable StateTable derivedStateTable; + private boolean mapStateViaMultimapState = false; private Builder( String stateFamily, @@ -268,6 +272,11 @@ Builder withDerivedState(StateTable derivedStateTable) { return this; } + Builder withMapStateViaMultimapState() { + this.mapStateViaMultimapState = true; + return this; + } + CachingStateTable build() { return new CachingStateTable(this); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java index 9f027af0a870..aed03f33e6d6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java @@ -21,10 +21,7 @@ import java.io.Closeable; import java.io.IOException; -import java.util.AbstractMap; -import java.util.Collections; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.function.Function; @@ -40,6 +37,8 @@ import org.apache.beam.sdk.util.Weighted; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; @@ -51,7 +50,7 @@ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -public class WindmillMap extends SimpleWindmillState implements MapState { +public class WindmillMap extends AbstractWindmillMap { private final StateNamespace namespace; private final StateTag> address; private final ByteString stateKeyPrefix; @@ -327,7 +326,7 @@ private class WindmillMapEntriesReadableState @Override public Iterable> read() { if (complete) { - return Iterables.unmodifiableIterable(cachedValues.entrySet()); + return ImmutableMap.copyOf(cachedValues).entrySet(); } Future>> persistedData = getFuture(); try (Closeable scope = scopedReadState()) { @@ -352,20 +351,22 @@ public Iterable> read() { cachedValues.putIfAbsent(e.getKey(), e.getValue()); }); complete = true; - return Iterables.unmodifiableIterable(cachedValues.entrySet()); + return ImmutableMap.copyOf(cachedValues).entrySet(); } else { + ImmutableMap cachedCopy = ImmutableMap.copyOf(cachedValues); + ImmutableSet removalCopy = ImmutableSet.copyOf(localRemovals); // This means that the result might be too large to cache, so don't add it to the // local cache. Instead merge the iterables, giving priority to any local additions - // (represented in cachedValued and localRemovals) that may not have been committed + // (represented in cachedCopy and removalCopy) that may not have been committed // yet. return Iterables.unmodifiableIterable( Iterables.concat( - cachedValues.entrySet(), + cachedCopy.entrySet(), Iterables.filter( transformedData, e -> - !cachedValues.containsKey(e.getKey()) - && !localRemovals.contains(e.getKey())))); + !cachedCopy.containsKey(e.getKey()) + && !removalCopy.contains(e.getKey())))); } } catch (InterruptedException | ExecutionException | IOException e) { @@ -428,7 +429,6 @@ public WindmillMapReadResultReadableState(K key, @Nullable V defaultValue) { negativeCache.add(key); return defaultValue; } - // TODO: Don't do this if it was already in cache. cachedValues.put(key, persistedValue); return persistedValue; } catch (InterruptedException | ExecutionException | IOException e) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java new file mode 100644 index 000000000000..0ee508a53baf --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.state; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; + +public class WindmillMapViaMultimap extends AbstractWindmillMap { + final WindmillMultimap multimap; + + WindmillMapViaMultimap(WindmillMultimap multimap) { + this.multimap = multimap; + } + + @Override + protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache) + throws IOException { + return multimap.persistDirectly(cache); + } + + @Override + void initializeForWorkItem( + WindmillStateReader reader, Supplier scopedReadStateSupplier) { + super.initializeForWorkItem(reader, scopedReadStateSupplier); + multimap.initializeForWorkItem(reader, scopedReadStateSupplier); + } + + @Override + void cleanupAfterWorkItem() { + super.cleanupAfterWorkItem(); + multimap.cleanupAfterWorkItem(); + } + + @Override + public void put(KeyT key, ValueT value) { + multimap.remove(key); + multimap.put(key, value); + } + + @Override + public ReadableState computeIfAbsent( + KeyT key, Function mappingFunction) { + // Note that computeIfAbsent comments indicate that the read is lazy but this matches the + // existing eager + // behavior of WindmillMap. + Iterable existingValues = multimap.get(key).read(); + if (Iterables.isEmpty(existingValues)) { + ValueT inserted = mappingFunction.apply(key); + multimap.put(key, inserted); + return ReadableStates.immediate(inserted); + } else { + return ReadableStates.immediate(Iterables.getOnlyElement(existingValues)); + } + } + + @Override + public void remove(KeyT key) { + multimap.remove(key); + } + + private static class SingleValueIterableAdaptor implements ReadableState { + final ReadableState> wrapped; + final @Nullable T defaultValue; + + SingleValueIterableAdaptor(ReadableState> wrapped, @Nullable T defaultValue) { + this.wrapped = wrapped; + this.defaultValue = defaultValue; + } + + @Override + public T read() { + Iterator iterator = wrapped.read().iterator(); + if (!iterator.hasNext()) { + return null; + } + return Iterators.getOnlyElement(iterator); + } + + @Override + public ReadableState readLater() { + wrapped.readLater(); + return this; + } + } + + @Override + public ReadableState get(KeyT key) { + return getOrDefault(key, null); + } + + @Override + public ReadableState getOrDefault(KeyT key, @Nullable ValueT defaultValue) { + return new SingleValueIterableAdaptor<>(multimap.get(key), defaultValue); + } + + @Override + public ReadableState> keys() { + return multimap.keys(); + } + + private static class RemoveKeyAdaptor implements ReadableState> { + final ReadableState>> wrapped; + + RemoveKeyAdaptor(ReadableState>> wrapped) { + this.wrapped = wrapped; + } + + @Override + public Iterable read() { + return Iterables.transform(wrapped.read(), Map.Entry::getValue); + } + + @Override + public ReadableState> readLater() { + wrapped.readLater(); + return this; + } + } + + @Override + public ReadableState> values() { + return new RemoveKeyAdaptor<>(multimap.entries()); + } + + @Override + public ReadableState>> entries() { + return multimap.entries(); + } + + @Override + public ReadableState isEmpty() { + return multimap.isEmpty(); + } + + @Override + public void clear() { + multimap.clear(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java index 75f33e69e0be..19c79a497d4c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java @@ -216,8 +216,8 @@ public void remove(K key) { if (keyState == null || keyState.existence == KeyExistence.KNOWN_NONEXISTENT) { return; } - if (keyState.valuesCached && keyState.valuesSize == 0) { - // no data in windmill, deleting from local cache is sufficient. + if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) { + // no data in windmill and no need to keep state, deleting from local cache is sufficient. keyStateMap.remove(structuralKey); } else { // there may be data in windmill that need to be removed. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java index 4afb879e722e..ee7e6862c7a1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java @@ -20,13 +20,7 @@ import java.io.Closeable; import java.io.IOException; import java.util.Optional; -import org.apache.beam.runners.core.StateNamespace; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.sdk.coders.BooleanCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.SetState; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; @@ -35,30 +29,10 @@ import org.checkerframework.checker.nullness.qual.UnknownKeyFor; public class WindmillSet extends SimpleWindmillState implements SetState { - private final WindmillMap windmillMap; - - WindmillSet( - StateNamespace namespace, - StateTag> address, - String stateFamily, - Coder keyCoder, - WindmillStateCache.ForKeyAndFamily cache, - boolean isNewKey) { - StateTag> internalMapAddress = StateTags.convertToMapTagInternal(address); - - this.windmillMap = - cache - .get(namespace, internalMapAddress) - .map(map -> (WindmillMap) map) - .orElseGet( - () -> - new WindmillMap<>( - namespace, - internalMapAddress, - stateFamily, - keyCoder, - BooleanCoder.of(), - isNewKey)); + private final AbstractWindmillMap windmillMap; + + WindmillSet(AbstractWindmillMap windmillMap) { + this.windmillMap = windmillMap; } @Override @@ -117,11 +91,13 @@ public void clear() { @Override void initializeForWorkItem( WindmillStateReader reader, Supplier scopedReadStateSupplier) { + super.initializeForWorkItem(reader, scopedReadStateSupplier); windmillMap.initializeForWorkItem(reader, scopedReadStateSupplier); } @Override void cleanupAfterWorkItem() { + super.cleanupAfterWorkItem(); windmillMap.cleanupAfterWorkItem(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java index c6c49134bcb5..64eb9dd941bb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.state; +import com.google.auto.value.AutoBuilder; import java.io.IOException; import java.io.PrintWriter; import java.util.HashMap; @@ -29,9 +30,7 @@ import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; -import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker; -import org.apache.beam.runners.dataflow.worker.Weighers; -import org.apache.beam.runners.dataflow.worker.WindmillComputationKey; +import org.apache.beam.runners.dataflow.worker.*; import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; @@ -76,26 +75,33 @@ public class WindmillStateCache implements StatusDataProvider { // entries inaccessible. They will be evicted through normal cache operation. private final ConcurrentMap keyIndex; private final long workerCacheBytes; // Copy workerCacheMb and convert to bytes. + private final boolean supportMapViaMultimap; - private WindmillStateCache( - long workerCacheMb, - ConcurrentMap keyIndex, - Cache stateCache) { - this.workerCacheBytes = workerCacheMb * MEGABYTES; - this.stateCache = stateCache; - this.keyIndex = keyIndex; - } - - public static WindmillStateCache ofSizeMbs(long workerCacheMb) { - return new WindmillStateCache( - workerCacheMb, - new MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(), + WindmillStateCache(long sizeMb, boolean supportMapViaMultimap) { + this.workerCacheBytes = sizeMb * MEGABYTES; + this.stateCache = CacheBuilder.newBuilder() - .maximumWeight(workerCacheMb * MEGABYTES) + .maximumWeight(workerCacheBytes) .recordStats() .weigher(Weighers.weightedKeysAndValues()) .concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL) - .build()); + .build(); + this.keyIndex = + new MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(); + this.supportMapViaMultimap = supportMapViaMultimap; + } + + @AutoBuilder(ofClass = WindmillStateCache.class) + public interface Builder { + Builder setSizeMb(long sizeMb); + + Builder setSupportMapViaMultimap(boolean supportMapViaMultimap); + + WindmillStateCache build(); + } + + public static Builder builder() { + return new AutoBuilder_WindmillStateCache_Builder().setSupportMapViaMultimap(false); } private EntryStats calculateEntryStats() { @@ -399,6 +405,10 @@ public String getStateFamily() { return stateFamily; } + public boolean supportMapStateViaMultimapState() { + return supportMapViaMultimap; + } + public Optional get(StateNamespace namespace, StateTag address) { @SuppressWarnings("nullness") // the mapping function for localCache.computeIfAbsent (i.e stateCache.getIfPresent) is diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java index c900228e86b0..f757db991fa7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java @@ -66,13 +66,13 @@ public WindmillStateInternals( this.key = key; this.cache = cache; this.scopedReadStateSupplier = scopedReadStateSupplier; - this.workItemDerivedState = - CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier) - .build(); - this.workItemState = - CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier) - .withDerivedState(workItemDerivedState) - .build(); + CachingStateTable.Builder builder = + CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier); + if (cache.supportMapStateViaMultimapState()) { + builder = builder.withMapStateViaMultimapState(); + } + this.workItemDerivedState = builder.build(); + this.workItemState = builder.withDerivedState(workItemDerivedState).build(); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 2193f20f3fe3..6c46bda5acfe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -112,7 +112,10 @@ public void setUp() { COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()), stateNameMap, - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"), + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation("comp"), StreamingStepMetricsContainer.createRegistry(), new DataflowExecutionStateTracker( ExecutionStateSampler.newForTest(), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java index 17da531d4525..8708b9f502d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java @@ -66,8 +66,8 @@ public static void assertNoReference( boolean accessible = f.isAccessible(); try { - f.setAccessible(true); path.add(thisClazz.getName() + "#" + f.getName()); + f.setAccessible(true); assertNoReference(f.get(obj), clazz, path, visited); } finally { path.remove(path.size() - 1); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 9f97c9835ddc..5d8ebd53400c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -964,7 +964,10 @@ public void testFailedWorkItemsAbort() throws Exception { COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Runnable::run), /*stateNameMap=*/ ImmutableMap.of(), - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation(COMPUTATION_ID), + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation(COMPUTATION_ID), StreamingStepMetricsContainer.createRegistry(), new DataflowExecutionStateTracker( ExecutionStateSampler.newForTest(), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java index 446a34f73dec..ce8da106b0ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java @@ -148,7 +148,7 @@ private static WindmillComputationKey computationKey( @Before public void setUp() { options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - cache = WindmillStateCache.ofSizeMbs(400); + cache = WindmillStateCache.builder().setSizeMb(400).build(); assertEquals(0, cache.getWeight()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java index a53240d64530..33e47623cd0e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java @@ -20,11 +20,7 @@ import static org.apache.beam.runners.dataflow.worker.DataflowMatchers.ByteStringMatcher.byteStringEq; import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Matchers.eq; @@ -130,7 +126,9 @@ public class WindmillStateInternalsTest { @Mock private WindmillStateReader mockReader; private WindmillStateInternals underTest; private WindmillStateInternals underTestNewKey; + private WindmillStateInternals underTestMapViaMultimap; private WindmillStateCache cache; + private WindmillStateCache cacheViaMultimap; @Mock private Supplier readStateSupplier; private static ByteString key(StateNamespace namespace, String addrId) { @@ -206,7 +204,12 @@ private static void assertTagMultimapUpdates( public void setUp() { MockitoAnnotations.initMocks(this); options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - cache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + cache = WindmillStateCache.builder().setSizeMb(options.getWorkerCacheMb()).build(); + cacheViaMultimap = + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(true) + .build(); resetUnderTest(); } @@ -242,6 +245,21 @@ public void resetUnderTest() { workToken) .forFamily(STATE_FAMILY), readStateSupplier); + underTestMapViaMultimap = + new WindmillStateInternals( + "dummyNewKey", + STATE_FAMILY, + mockReader, + false, + cacheViaMultimap + .forComputation("comp") + .forKey( + WindmillComputationKey.create( + "comp", ByteString.copyFrom("dummyNewKey", Charsets.UTF_8), 123), + 17L, + workToken) + .forFamily(STATE_FAMILY), + readStateSupplier); } @After @@ -249,6 +267,7 @@ public void tearDown() throws Exception { // Make sure no WindmillStateReader (a per-WorkItem object) escapes into the cache // (a global object). WindmillStateTestUtils.assertNoReference(cache, WindmillStateReader.class); + WindmillStateTestUtils.assertNoReference(cacheViaMultimap, WindmillStateReader.class); } private void waitAndSet(final SettableFuture future, final T value, final long millis) { @@ -741,6 +760,38 @@ public void testMultimapGet() { assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3)); } + @Test + public void testMapViaMultimapGet() { + final String tag = "map"; + StateTag> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + SettableFuture> future1 = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future1); + SettableFuture> future2 = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key2, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future2); + + ReadableState result1 = mapViaMultiMapState.get(dup(key1)).readLater(); + ReadableState result2 = mapViaMultiMapState.get(dup(key2)).readLater(); + waitAndSet(future1, Collections.singletonList(1), 30); + waitAndSet(future2, Collections.emptyList(), 1); + assertEquals(Integer.valueOf(1), result1.read()); + assertNull(result2.read()); + } + @Test public void testMultimapPutAndGet() { final String tag = "multimap"; @@ -761,6 +812,41 @@ public void testMultimapPutAndGet() { ReadableState> result = multimapState.get(dup(key)).readLater(); waitAndSet(future, Arrays.asList(1, 2, 3), 30); assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 3)); + + multimapState.remove(key); + multimapState.put(key, 4); + multimapState.remove(key); + multimapState.put(key, 5); + assertThat(result.read(), Matchers.containsInAnyOrder(5)); + multimapState.clear(); + assertThat(multimapState.get(key).read(), Matchers.emptyIterable()); + } + + @Test + public void testMapViaMultimapPutAndGet() { + final String tag = "map"; + StateTag> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + mapViaMultiMapState.put(key, 1); + ReadableState result = mapViaMultiMapState.get(dup(key)).readLater(); + waitAndSet(future, Collections.singletonList(2), 30); + assertEquals(Integer.valueOf(1), result.read()); + + mapViaMultiMapState.put(key, 3); + assertEquals(Integer.valueOf(3), mapViaMultiMapState.get(key).read()); + mapViaMultiMapState.clear(); + assertNull(mapViaMultiMapState.get(key).read()); } @Test @@ -791,6 +877,33 @@ public void testMultimapRemoveAndGet() { assertThat(result2.read(), Matchers.emptyIterable()); } + @Test + public void testMapViaMultimapRemoveAndGet() { + final String tag = "map"; + StateTag> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState result1 = mapViaMultiMapState.get(key).readLater(); + ReadableState result2 = mapViaMultiMapState.get(dup(key)).readLater(); + waitAndSet(future, Collections.singletonList(1), 30); + + assertEquals(Integer.valueOf(1), result1.read()); + + mapViaMultiMapState.remove(key); + assertNull(mapViaMultiMapState.get(dup(key)).read()); + assertNull(result2.read()); + } + @Test public void testMultimapRemoveThenPut() { final String tag = "multimap"; @@ -1030,6 +1143,64 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); } + @Test + public void testMapViaMultimapEntriesAndKeysMergeLocalAddRemoveClear() { + final String tag = "map"; + StateTag> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState mapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + mapState.entries().readLater(); + ReadableState> keysResult = mapState.keys().readLater(); + waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 3), multimapEntry(key2, 4)), 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + mapState.put(key1, 7); + mapState.put(dup(key3), 8); + mapState.put(key4, 1); + mapState.remove(key4); + + Iterable> entries = entriesResult.read(); + assertEquals(3, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 7), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key3, 8))); + + Iterable keys = keysResult.read(); + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + assertFalse(mapState.isEmpty().read()); + + mapState.clear(); + assertTrue(mapState.isEmpty().read()); + assertTrue(Iterables.isEmpty(mapState.keys().read())); + assertTrue(Iterables.isEmpty(mapState.entries().read())); + + // Previously read iterable should still have the same result. + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + } + @Test public void testMultimapEntriesAndKeysMergeLocalRemove() { final String tag = "multimap"; @@ -1080,6 +1251,48 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); } + @Test + public void testMapViaMultimapEntriesAndKeysMergeLocalRemove() { + final String tag = "map"; + StateTag> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState mapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + mapState.entries().readLater(); + ReadableState> keysResult = mapState.keys().readLater(); + waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 1), multimapEntry(key2, 2)), 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + mapState.remove(dup(key1)); + mapState.put(key2, 8); + mapState.put(dup(key3), 9); + + Iterable> entries = entriesResult.read(); + assertEquals(2, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder(multimapEntryMatcher(key2, 8), multimapEntryMatcher(key3, 9))); + + Iterable keys = keysResult.read(); + assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); + } + @Test public void testMultimapCacheComplete() { final String tag = "multimap"; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java index 175c8421ff8e..13019116767c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java @@ -207,7 +207,7 @@ public void testInvalidateStuckCommits() throws InterruptedException { int stuckCommitDurationMillis = 100; Table computations = HashBasedTable.create(); - WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100); + WindmillStateCache stateCache = WindmillStateCache.builder().setSizeMb(100).build(); ByteString key = ByteString.EMPTY; for (int i = 0; i < 5; i++) { WindmillStateCache.ForComputation perComputationStateCache = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 942881522cf0..df5084ad0921 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -377,6 +377,25 @@ public static StateSpec> convertToMapSpecInternal } } + /** + * For internal use only; no backwards-compatibility guarantees. + * + *

Convert a set state spec to a map-state spec. + */ + @Internal + public static StateSpec> convertToMultimapSpecInternal( + StateSpec> spec) { + if (spec instanceof MapStateSpec) { + // Checked above; conversion to a map spec depends on the provided spec being one of those + // created via the factory methods in this class. + @SuppressWarnings("unchecked") + MapStateSpec typedSpec = (MapStateSpec) spec; + return typedSpec.asMultimapSpec(); + } else { + throw new IllegalArgumentException("Unexpected StateSpec " + spec); + } + } + /** * A specification for a state cell holding a settable value of type {@code T}. * @@ -768,6 +787,10 @@ public boolean equals(@Nullable Object obj) { public int hashCode() { return Objects.hash(getClass(), keyCoder, valueCoder); } + + private MultimapStateSpec asMultimapSpec() { + return new MultimapStateSpec<>(this.keyCoder, this.valueCoder); + } } private static class MultimapStateSpec implements StateSpec> { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 89dcafbdf94f..fb2321328b32 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -2709,19 +2709,26 @@ public void processElement( @StateId(countStateId) CombiningState count, OutputReceiver> r) { KV value = element.getValue(); - ReadableState>> entriesView = state.entries(); state.put(value.getKey(), value.getValue()); count.add(1); + + @Nullable Integer max = state.get("max").read(); + state.put("max", Math.max(max == null ? 0 : max, value.getValue())); if (count.read() >= 4) { - Iterable> iterate = state.entries().read(); + assertEquals(Integer.valueOf(97), state.get("a").read()); + + Iterable> entriesView = state.entries().read(); + Iterable keysView = state.keys().read(); // Make sure that the cached Iterable doesn't change when new elements are added, // but that cached ReadableState views of the state do change. state.put("BadKey", -1); - assertEquals(3, Iterables.size(iterate)); - assertEquals(4, Iterables.size(entriesView.read())); - assertEquals(4, Iterables.size(state.entries().read())); + assertEquals(4, Iterables.size(entriesView)); + assertEquals(4, Iterables.size(keysView)); + assertEquals(5, Iterables.size(state.entries().read())); + assertEquals(5, Iterables.size(state.keys().read())); + assertEquals(Integer.valueOf(97), state.get("max").read()); - for (Map.Entry entry : iterate) { + for (Map.Entry entry : entriesView) { r.output(KV.of(entry.getKey(), entry.getValue())); } } @@ -2732,11 +2739,14 @@ public void processElement( pipeline .apply( Create.of( - KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)), - KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12)))) + KV.of("hello", KV.of("a", 97)), + KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("c", 12)))) .apply(ParDo.of(fn)); - PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12)); + PAssert.that(output) + .containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12), KV.of("max", 97)); pipeline.run(); }