diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 06450e60fc05..825c3fb78c7d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -57,6 +57,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import javax.servlet.http.HttpServletRequest; @@ -96,6 +97,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.Work.State; +import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -104,13 +106,16 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; +import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.coders.Coder; @@ -208,7 +213,7 @@ public class StreamingDataflowWorker { private static final Random clientIdGenerator = new Random(); final WindmillStateCache stateCache; // Maps from computation ids to per-computation state. - private final ConcurrentMap computationMap = new ConcurrentHashMap<>(); + private final ConcurrentMap computationMap; private final WeightedBoundedQueue commitQueue = WeightedBoundedQueue.create( MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); @@ -280,8 +285,7 @@ public class StreamingDataflowWorker { // Periodic sender of debug information to the debug capture service. private final DebugCapture.@Nullable Manager debugCaptureManager; // Collection of ScheduledExecutorServices that are running periodic functions. - private ArrayList scheduledExecutors = - new ArrayList(); + private final ArrayList scheduledExecutors = new ArrayList<>(); private int retryLocallyDelayMs = 10000; // Periodically fires a global config request to dataflow service. Only used when windmill service // is enabled. 
@@ -292,6 +296,9 @@ public class StreamingDataflowWorker { @VisibleForTesting StreamingDataflowWorker( + WindmillServerStub windmillServer, + long clientId, + ConcurrentMap computationMap, List mapTasks, DataflowMapTaskExecutorFactory mapTaskExecutorFactory, WorkUnitClient workUnitClient, @@ -299,13 +306,13 @@ public class StreamingDataflowWorker { boolean publishCounters, HotKeyLogger hotKeyLogger, Supplier clock, - Function executorSupplier) - throws IOException { + Function executorSupplier) { this.stateCache = new WindmillStateCache(options.getWorkerCacheMb()); this.readerCache = new ReaderCache( Duration.standardSeconds(options.getReaderCacheTimeoutSec()), Executors.newCachedThreadPool()); + this.computationMap = computationMap; this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.workUnitClient = workUnitClient; this.options = options; @@ -429,8 +436,8 @@ public void run() { commitThreads = commitThreadsBuilder.build(); this.publishCounters = publishCounters; - this.windmillServer = options.getWindmillServerStub(); - this.windmillServer.setProcessHeartbeatResponses(this::handleHeartbeatResponses); + this.clientId = clientId; + this.windmillServer = windmillServer; this.metricTrackingWindmillServer = MetricTrackingWindmillServerStub.builder(windmillServer, memoryMonitor) .setUseStreamingRequests(windmillServiceEnabled) @@ -438,7 +445,6 @@ public void run() { .setNumGetDataStreams(options.getWindmillGetDataStreamCount()) .build(); this.sideInputStateFetcher = new SideInputStateFetcher(metricTrackingWindmillServer, options); - this.clientId = clientIdGenerator.nextLong(); for (MapTask mapTask : mapTasks) { addComputation(mapTask.getSystemName(), mapTask, ImmutableMap.of()); @@ -456,6 +462,44 @@ public void run() { LOG.debug("maxWorkItemCommitBytes: {}", maxWorkItemCommitBytes); } + private static WindmillServerStub createWindmillServerStub( + StreamingDataflowWorkerOptions options, + long clientId, + Consumer> processHeartbeatResponses) { + if 
(options.getWindmillServiceEndpoint() != null + || options.isEnableStreamingEngine() + || options.getLocalWindmillHostport().startsWith("grpc:")) { + try { + Duration maxBackoff = + !options.isEnableStreamingEngine() && options.getLocalWindmillHostport() != null + ? GrpcWindmillServer.LOCALHOST_MAX_BACKOFF + : GrpcWindmillServer.MAX_BACKOFF; + GrpcWindmillStreamFactory windmillStreamFactory = + GrpcWindmillStreamFactory.of( + JobHeader.newBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .setClientId(clientId) + .build()) + .setWindmillMessagesBetweenIsReadyChecks( + options.getWindmillMessagesBetweenIsReadyChecks()) + .setMaxBackOffSupplier(() -> maxBackoff) + .setLogEveryNStreamFailures( + options.getWindmillServiceStreamingLogEveryNStreamFailures()) + .setStreamingRpcBatchLimit(options.getWindmillServiceStreamingRpcBatchLimit()) + .build(); + windmillStreamFactory.scheduleHealthChecks( + options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()); + return GrpcWindmillServer.create(options, windmillStreamFactory, processHeartbeatResponses); + } catch (IOException e) { + throw new RuntimeException("Failed to create GrpcWindmillServer: ", e); + } + } else { + return new JniWindmillApplianceServer(options.getLocalWindmillHostport()); + } + } + /** Returns whether an exception was caused by a {@link OutOfMemoryError}. 
*/ private static boolean isOutOfMemoryError(Throwable t) { while (t != null) { @@ -509,10 +553,17 @@ public static void main(String[] args) throws Exception { worker.start(); } - public static StreamingDataflowWorker fromOptions(StreamingDataflowWorkerOptions options) - throws IOException { - + public static StreamingDataflowWorker fromOptions(StreamingDataflowWorkerOptions options) { + ConcurrentMap computationMap = new ConcurrentHashMap<>(); + long clientId = clientIdGenerator.nextLong(); return new StreamingDataflowWorker( + createWindmillServerStub( + options, + clientId, + new WorkHeartbeatResponseProcessor( + computationId -> Optional.ofNullable(computationMap.get(computationId)))), + clientId, + computationMap, Collections.emptyList(), IntrinsicMapTaskExecutorFactory.defaultFactory(), new DataflowWorkUnitClient(options, LOG), @@ -1626,7 +1677,6 @@ private void getConfigFromDataflowService(@Nullable String computation) throws I @SuppressWarnings("FutureReturnValueIgnored") private void schedulePeriodicGlobalConfigRequests() { Preconditions.checkState(windmillServiceEnabled); - if (!windmillServer.isReady()) { // Get the initial global configuration. This will initialize the windmillServer stub. while (true) { @@ -1975,26 +2025,6 @@ private void sendWorkerUpdatesToDataflowService( } } - public void handleHeartbeatResponses(List responses) { - for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) { - // Maps sharding key to (work token, cache token) for work that should be marked failed. 
- Multimap failedWork = ArrayListMultimap.create(); - for (Windmill.HeartbeatResponse heartbeatResponse : - computationHeartbeatResponse.getHeartbeatResponsesList()) { - if (heartbeatResponse.getFailed()) { - failedWork.put( - heartbeatResponse.getShardingKey(), - WorkId.builder() - .setWorkToken(heartbeatResponse.getWorkToken()) - .setCacheToken(heartbeatResponse.getCacheToken()) - .build()); - } - } - ComputationState state = computationMap.get(computationHeartbeatResponse.getComputationId()); - if (state != null) state.failWork(failedWork); - } - } - /** * Sends a GetData request to Windmill for all sufficiently old active work. * diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java index a75a60af2ba0..9431470a16fb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java @@ -17,11 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.options; -import java.io.IOException; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; -import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; @@ -36,12 +32,6 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public interface StreamingDataflowWorkerOptions 
extends DataflowWorkerHarnessOptions { - @Description("Stub for communicating with Windmill.") - @Default.InstanceFactory(WindmillServerStubFactory.class) - WindmillServerStub getWindmillServerStub(); - - void setWindmillServerStub(WindmillServerStub value); - @Description("Hostport of a co-located Windmill server.") @Default.InstanceFactory(LocalWindmillHostportFactory.class) String getLocalWindmillHostport(); @@ -168,29 +158,6 @@ public String create(PipelineOptions options) { } } - /** - * Factory for creating {@link WindmillServerStub} instances. If {@link setLocalWindmillHostport} - * is set, returns a stub to a local Windmill server, otherwise returns a remote gRPC stub. - */ - public static class WindmillServerStubFactory implements DefaultValueFactory { - @Override - public WindmillServerStub create(PipelineOptions options) { - StreamingDataflowWorkerOptions streamingOptions = - options.as(StreamingDataflowWorkerOptions.class); - if (streamingOptions.getWindmillServiceEndpoint() != null - || streamingOptions.isEnableStreamingEngine() - || streamingOptions.getLocalWindmillHostport().startsWith("grpc:")) { - try { - return GrpcWindmillServer.create(streamingOptions); - } catch (IOException e) { - throw new RuntimeException("Failed to create GrpcWindmillServer: ", e); - } - } else { - return new JniWindmillApplianceServer(streamingOptions.getLocalWindmillHostport()); - } - } - } - /** Factory for setting value of WindmillServiceStreamingRpcBatchLimit based on environment. 
*/ public static class WindmillServiceStreamingRpcBatchLimitFactory implements DefaultValueFactory { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java new file mode 100644 index 000000000000..341f434cefa4 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; + +/** + * Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that is invalid from + * Streaming Engine backend so that it gets dropped from streaming worker harness processing. + */ +@Internal +public final class WorkHeartbeatResponseProcessor + implements Consumer> { + /** Fetches a {@link ComputationState} for a computationId. */ + private final Function> computationStateFetcher; + + public WorkHeartbeatResponseProcessor( + /* Fetches a {@link ComputationState} for a String computationId. */ + Function> computationStateFetcher) { + this.computationStateFetcher = computationStateFetcher; + } + + @Override + public void accept(List responses) { + for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) { + // Maps sharding key to (work token, cache token) for work that should be marked failed. 
+ Multimap failedWork = ArrayListMultimap.create(); + for (HeartbeatResponse heartbeatResponse : + computationHeartbeatResponse.getHeartbeatResponsesList()) { + if (heartbeatResponse.getFailed()) { + failedWork.put( + heartbeatResponse.getShardingKey(), + WorkId.builder() + .setWorkToken(heartbeatResponse.getWorkToken()) + .setCacheToken(heartbeatResponse.getCacheToken()) + .build()); + } + } + + computationStateFetcher + .apply(computationHeartbeatResponse.getComputationId()) + .ifPresent(state -> state.failWork(failedWork)); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index 64b6e675ef5f..d7ed83def43e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -26,6 +26,7 @@ import java.net.UnknownHostException; import java.util.Map; import java.util.Optional; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; @@ -40,23 +41,6 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); - /** - * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key - * is a global data tag and the value is the endpoint where the data associated with the global - * data tag resides. 
- * - * @see Beam Side - * Inputs - */ - public abstract ImmutableMap globalDataEndpoints(); - - /** - * Used by GetWork/GetData/CommitWork calls to send, receive, and commit work directly to/from - * Windmill servers. Returns a list of endpoints used to communicate with the corresponding - * Windmill servers. - */ - public abstract ImmutableList windmillEndpoints(); - public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = @@ -64,11 +48,16 @@ public static WindmillEndpoints from( .collect( toImmutableMap( Map.Entry::getKey, // global data key - endpoint -> WindmillEndpoints.Endpoint.from(endpoint.getValue()))); + endpoint -> + WindmillEndpoints.Endpoint.from( + endpoint.getValue(), + workerMetadataResponseProto.getExternalEndpoint()))); ImmutableList windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() - .map(WindmillEndpoints.Endpoint::from) + .map( + endpointProto -> + Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) .collect(toImmutableList()); return WindmillEndpoints.builder() @@ -81,6 +70,76 @@ public static WindmillEndpoints.Builder builder() { return new AutoValue_WindmillEndpoints.Builder(); } + private static Optional parseDirectEndpoint( + Windmill.WorkerMetadataResponse.Endpoint endpointProto, String authenticatingService) { + Optional directEndpointIpV6Address = + tryParseDirectEndpointIntoIpV6Address(endpointProto) + .map(address -> AuthenticatedGcpServiceAddress.create(authenticatingService, address)) + .map(WindmillServiceAddress::create); + + return directEndpointIpV6Address.isPresent() + ? 
directEndpointIpV6Address + : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint()) + .map(WindmillServiceAddress::create); + } + + private static Optional tryParseEndpointIntoHostAndPort(String directEndpoint) { + try { + return Optional.of(HostAndPort.fromString(directEndpoint)); + } catch (IllegalArgumentException e) { + LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint); + return Optional.empty(); + } + } + + private static Optional tryParseDirectEndpointIntoIpV6Address( + Windmill.WorkerMetadataResponse.Endpoint endpointProto) { + if (!endpointProto.hasDirectEndpoint()) { + return Optional.empty(); + } + + InetAddress directEndpointAddress; + try { + directEndpointAddress = Inet6Address.getByName(endpointProto.getDirectEndpoint()); + } catch (UnknownHostException e) { + LOG.warn( + "Error occurred trying to parse direct_endpoint={} into IPv6 address. Exception={}", + endpointProto.getDirectEndpoint(), + e.toString()); + return Optional.empty(); + } + + // Inet6Address.getByName returns either an IPv4 or an IPv6 address depending on the format + // of the direct_endpoint string. + if (!(directEndpointAddress instanceof Inet6Address)) { + LOG.warn( + "{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.", + endpointProto.getDirectEndpoint()); + return Optional.empty(); + } + + return Optional.of( + HostAndPort.fromParts( + directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); + } + + /** + * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key + * is a global data tag and the value is the endpoint where the data associated with the global + * data tag resides. + * + * @see Beam Side + * Inputs + */ + public abstract ImmutableMap globalDataEndpoints(); + + /** + * Used by GetWork/GetData/CommitWork calls to send, receive, and commit work directly to/from + * Windmill servers. 
Returns a list of endpoints used to communicate with the corresponding + * Windmill servers. + */ + public abstract ImmutableList windmillEndpoints(); + /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with * the worker_token field, and direct_endpoint field parsed into a {@link WindmillServiceAddress} @@ -90,31 +149,21 @@ public static WindmillEndpoints.Builder builder() { */ @AutoValue public abstract static class Endpoint { - /** - * {@link WindmillServiceAddress} representation of {@link - * Windmill.WorkerMetadataResponse.Endpoint#getDirectEndpoint()}. The proto's direct_endpoint - * string can be converted to either {@link Inet6Address} or {@link HostAndPort}. - */ - public abstract Optional directEndpoint(); - - /** - * Corresponds to {@link Windmill.WorkerMetadataResponse.Endpoint#getWorkerToken()} in the - * windmill.proto file. - */ - public abstract Optional workerToken(); - public static Endpoint.Builder builder() { return new AutoValue_WindmillEndpoints_Endpoint.Builder(); } - public static Endpoint from(Windmill.WorkerMetadataResponse.Endpoint endpointProto) { + public static Endpoint from( + Windmill.WorkerMetadataResponse.Endpoint endpointProto, String authenticatingService) { Endpoint.Builder endpointBuilder = Endpoint.builder(); - if (endpointProto.hasDirectEndpoint() && !endpointProto.getDirectEndpoint().isEmpty()) { - parseDirectEndpoint(endpointProto.getDirectEndpoint()) + + if (!endpointProto.getDirectEndpoint().isEmpty()) { + parseDirectEndpoint(endpointProto, authenticatingService) .ifPresent(endpointBuilder::setDirectEndpoint); } - if (endpointProto.hasWorkerToken() && !endpointProto.getWorkerToken().isEmpty()) { - endpointBuilder.setWorkerToken(endpointProto.getWorkerToken()); + + if (!endpointProto.getBackendWorkerToken().isEmpty()) { + endpointBuilder.setWorkerToken(endpointProto.getBackendWorkerToken()); } Endpoint endpoint = endpointBuilder.build(); @@ -130,6 +179,19 @@ public static 
Endpoint from(Windmill.WorkerMetadataResponse.Endpoint endpointPro return endpoint; } + /** + * {@link WindmillServiceAddress} representation of {@link + * Windmill.WorkerMetadataResponse.Endpoint#getDirectEndpoint()}. The proto's direct_endpoint + * string can be converted to either {@link Inet6Address} or {@link HostAndPort}. + */ + public abstract Optional directEndpoint(); + + /** + * Corresponds to {@link Windmill.WorkerMetadataResponse.Endpoint#getBackendWorkerToken()} + * in the windmill.proto file. + */ + public abstract Optional workerToken(); + @AutoValue.Builder public abstract static class Builder { public abstract Builder setDirectEndpoint(WindmillServiceAddress directEndpoint); @@ -176,46 +238,4 @@ public final Builder addAllGlobalDataEndpoints( public abstract WindmillEndpoints build(); } - - private static Optional parseDirectEndpoint(String directEndpoint) { - Optional directEndpointIpV6Address = - tryParseDirectEndpointIntoIpV6Address(directEndpoint).map(WindmillServiceAddress::create); - - return directEndpointIpV6Address.isPresent() - ? directEndpointIpV6Address - : tryParseEndpointIntoHostAndPort(directEndpoint).map(WindmillServiceAddress::create); - } - - private static Optional tryParseEndpointIntoHostAndPort(String directEndpoint) { - try { - return Optional.of(HostAndPort.fromString(directEndpoint)); - } catch (IllegalArgumentException e) { - LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint); - return Optional.empty(); - } - } - - private static Optional tryParseDirectEndpointIntoIpV6Address( - String directEndpoint) { - InetAddress directEndpointAddress = null; - try { - directEndpointAddress = Inet6Address.getByName(directEndpoint); - } catch (UnknownHostException e) { - LOG.warn( - "Error occurred trying to parse direct_endpoint={} into IPv6 address. 
Exception={}", - directEndpoint, - e.toString()); - } - - // Inet6Address.getByAddress returns either an IPv4 or an IPv6 address depending on the format - // of the direct_endpoint string. - if (!(directEndpointAddress instanceof Inet6Address)) { - LOG.warn( - "{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.", - directEndpoint); - return Optional.empty(); - } - - return Optional.ofNullable((Inet6Address) directEndpointAddress); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index 25581bee2089..c327e68d7e91 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -19,11 +19,8 @@ import java.io.IOException; import java.io.PrintWriter; -import java.util.List; import java.util.Set; -import java.util.function.Consumer; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; @@ -82,9 +79,6 @@ public abstract GetWorkStream getWorkStream( @Override public void appendSummaryHtml(PrintWriter writer) {} - public void setProcessHeartbeatResponses( - Consumer> processHeartbeatResponses) {} - /** Generic Exception type for implementors to use to represent errors while making RPCs. 
*/ public static final class RpcException extends RuntimeException { public RpcException(Throwable cause) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 3ebda8fab8ed..90f93b072673 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill; import com.google.auto.value.AutoOneOf; +import com.google.auto.value.AutoValue; import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; @@ -38,8 +39,33 @@ public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { public abstract HostAndPort gcpServiceAddress(); + public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); + + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } + public enum Kind { IPV6, - GCP_SERVICE_ADDRESS + GCP_SERVICE_ADDRESS, + // TODO(m-trieu): Use for direct connections when ALTS is enabled. + AUTHENTICATED_GCP_SERVICE_ADDRESS + } + + @AutoValue + public abstract static class AuthenticatedGcpServiceAddress { + + public static AuthenticatedGcpServiceAddress create( + String authenticatingService, HostAndPort gcpServiceAddress) { + // HostAndPort supports IpV6. 
+ return new AutoValue_WindmillServiceAddress_AuthenticatedGcpServiceAddress( + authenticatingService, gcpServiceAddress); + } + + public abstract String authenticatingService(); + + public abstract HostAndPort gcpServiceAddress(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index ef9156f9c050..aa15e0a5e1a6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -20,19 +20,22 @@ import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.LOCALHOST; import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.localhostChannel; -import java.util.ArrayList; -import java.util.HashSet; +import com.google.auto.value.AutoValue; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; @@ -44,93 +47,178 @@ class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); private final WindmillStubFactory windmillStubFactory; - @GuardedBy("this") - private final List dispatcherStubs; - - @GuardedBy("this") - private final Set dispatcherEndpoints; + /** + * Current dispatcher endpoints and stubs used to communicate with Windmill Dispatcher. + * + * @implNote Reads are lock free, writes are synchronized. 
+ */ + private final AtomicReference dispatcherStubs; @GuardedBy("this") private final Random rand; private GrpcDispatcherClient( WindmillStubFactory windmillStubFactory, - List dispatcherStubs, - Set dispatcherEndpoints, + DispatcherStubs initialDispatcherStubs, Random rand) { this.windmillStubFactory = windmillStubFactory; - this.dispatcherStubs = dispatcherStubs; - this.dispatcherEndpoints = dispatcherEndpoints; this.rand = rand; + this.dispatcherStubs = new AtomicReference<>(initialDispatcherStubs); } static GrpcDispatcherClient create(WindmillStubFactory windmillStubFactory) { - return new GrpcDispatcherClient( - windmillStubFactory, new ArrayList<>(), new HashSet<>(), new Random()); + return new GrpcDispatcherClient(windmillStubFactory, DispatcherStubs.empty(), new Random()); } @VisibleForTesting static GrpcDispatcherClient forTesting( WindmillStubFactory windmillGrpcStubFactory, - List dispatcherStubs, + List windmillServiceStubs, + List windmillMetadataServiceStubs, Set dispatcherEndpoints) { - Preconditions.checkArgument(dispatcherEndpoints.size() == dispatcherStubs.size()); + Preconditions.checkArgument( + dispatcherEndpoints.size() == windmillServiceStubs.size() + && windmillServiceStubs.size() == windmillMetadataServiceStubs.size()); return new GrpcDispatcherClient( - windmillGrpcStubFactory, dispatcherStubs, dispatcherEndpoints, new Random()); + windmillGrpcStubFactory, + DispatcherStubs.create( + dispatcherEndpoints, windmillServiceStubs, windmillMetadataServiceStubs), + new Random()); + } + + CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { + ImmutableList windmillServiceStubs = + dispatcherStubs.get().windmillServiceStubs(); + Preconditions.checkState( + !windmillServiceStubs.isEmpty(), "windmillServiceEndpoint has not been set"); + + return (windmillServiceStubs.size() == 1 + ? 
windmillServiceStubs.get(0) + : randomlySelectNextStub(windmillServiceStubs)); } - synchronized CloudWindmillServiceV1Alpha1Stub getDispatcherStub() { + CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStub() { + ImmutableList windmillMetadataServiceStubs = + dispatcherStubs.get().windmillMetadataServiceStubs(); Preconditions.checkState( - !dispatcherStubs.isEmpty(), "windmillServiceEndpoint has not been set"); + !windmillMetadataServiceStubs.isEmpty(), "windmillServiceEndpoint has not been set"); + + return (windmillMetadataServiceStubs.size() == 1 + ? windmillMetadataServiceStubs.get(0) + : randomlySelectNextStub(windmillMetadataServiceStubs)); + } - return (dispatcherStubs.size() == 1 - ? dispatcherStubs.get(0) - : dispatcherStubs.get(rand.nextInt(dispatcherStubs.size()))); + private synchronized T randomlySelectNextStub(List stubs) { + return stubs.get(rand.nextInt(stubs.size())); } - synchronized boolean isReady() { - return !dispatcherStubs.isEmpty(); + /** + * Returns whether the {@link DispatcherStubs} have been set. Once initially set, {@link + * #dispatcherStubs} will always have a value as empty updates will trigger an {@link + * IllegalStateException}. + */ + boolean hasInitializedEndpoints() { + return dispatcherStubs.get().hasInitializedEndpoints(); } synchronized void consumeWindmillDispatcherEndpoints( ImmutableSet dispatcherEndpoints) { + ImmutableSet currentDispatcherEndpoints = + dispatcherStubs.get().dispatcherEndpoints(); Preconditions.checkArgument( dispatcherEndpoints != null && !dispatcherEndpoints.isEmpty(), "Cannot set dispatcher endpoints to nothing."); - if (this.dispatcherEndpoints.equals(dispatcherEndpoints)) { + if (currentDispatcherEndpoints.equals(dispatcherEndpoints)) { // The endpoints are equal don't recreate the stubs. 
return; } LOG.info("Creating a new windmill stub, endpoints: {}", dispatcherEndpoints); - if (!this.dispatcherEndpoints.isEmpty()) { - LOG.info("Previous windmill stub endpoints: {}", this.dispatcherEndpoints); + if (!currentDispatcherEndpoints.isEmpty()) { + LOG.info("Previous windmill stub endpoints: {}", currentDispatcherEndpoints); } - resetDispatcherEndpoints(dispatcherEndpoints); + LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", dispatcherEndpoints); + dispatcherStubs.set(DispatcherStubs.create(dispatcherEndpoints, windmillStubFactory)); } - private synchronized void resetDispatcherEndpoints( - ImmutableSet newDispatcherEndpoints) { - LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", newDispatcherEndpoints); - this.dispatcherStubs.clear(); - this.dispatcherEndpoints.clear(); - this.dispatcherEndpoints.addAll(newDispatcherEndpoints); + /** + * Endpoints and gRPC stubs used to communicate with the Windmill Dispatcher. {@link + * #dispatcherEndpoints()}, {@link #windmillServiceStubs()}, and {@link + * #windmillMetadataServiceStubs()} collections should all be of the same size. 
+ */ + @AutoValue + abstract static class DispatcherStubs { - dispatcherEndpoints.stream() - .map(this::createDispatcherStubForWindmillService) - .forEach(dispatcherStubs::add); - } + private static DispatcherStubs empty() { + return create(ImmutableSet.of(), ImmutableList.of(), ImmutableList.of()); + } - private CloudWindmillServiceV1Alpha1Stub createDispatcherStubForWindmillService( - HostAndPort endpoint) { - if (LOCALHOST.equals(endpoint.getHost())) { - return CloudWindmillServiceV1Alpha1Grpc.newStub(localhostChannel(endpoint.getPort())); + private static DispatcherStubs create( + Set endpoints, + List windmillServiceStubs, + List windmillMetadataServiceStubs) { + Preconditions.checkState( + endpoints.size() == windmillServiceStubs.size() + && windmillServiceStubs.size() == windmillMetadataServiceStubs.size(), + "Dispatcher should have the same number of endpoints and stubs"); + return new AutoValue_GrpcDispatcherClient_DispatcherStubs( + ImmutableSet.copyOf(endpoints), + ImmutableList.copyOf(windmillServiceStubs), + ImmutableList.copyOf(windmillMetadataServiceStubs)); } - // Use an in-process stub if testing. - return windmillStubFactory.getKind() == WindmillStubFactory.Kind.IN_PROCESS - ? 
windmillStubFactory.inProcess().get() - : windmillStubFactory.remote().apply(WindmillServiceAddress.create(endpoint)); + private static DispatcherStubs create( + ImmutableSet newDispatcherEndpoints, WindmillStubFactory windmillStubFactory) { + ImmutableList.Builder windmillServiceStubs = + ImmutableList.builder(); + ImmutableList.Builder windmillMetadataServiceStubs = + ImmutableList.builder(); + + for (HostAndPort endpoint : newDispatcherEndpoints) { + windmillServiceStubs.add(createWindmillServiceStub(endpoint, windmillStubFactory)); + windmillMetadataServiceStubs.add( + createWindmillMetadataServiceStub(endpoint, windmillStubFactory)); + } + + return new AutoValue_GrpcDispatcherClient_DispatcherStubs( + newDispatcherEndpoints, + windmillServiceStubs.build(), + windmillMetadataServiceStubs.build()); + } + + private static CloudWindmillServiceV1Alpha1Stub createWindmillServiceStub( + HostAndPort endpoint, WindmillStubFactory windmillStubFactory) { + if (LOCALHOST.equals(endpoint.getHost())) { + return CloudWindmillServiceV1Alpha1Grpc.newStub(localhostChannel(endpoint.getPort())); + } + + return windmillStubFactory.createWindmillServiceStub(WindmillServiceAddress.create(endpoint)); + } + + private static CloudWindmillMetadataServiceV1Alpha1Stub createWindmillMetadataServiceStub( + HostAndPort endpoint, WindmillStubFactory windmillStubFactory) { + if (LOCALHOST.equals(endpoint.getHost())) { + return CloudWindmillMetadataServiceV1Alpha1Grpc.newStub( + localhostChannel(endpoint.getPort())); + } + + return windmillStubFactory.createWindmillMetadataServiceStub( + WindmillServiceAddress.create(endpoint)); + } + + private int size() { + return dispatcherEndpoints().size(); + } + + private boolean hasInitializedEndpoints() { + return size() > 0; + } + + abstract ImmutableSet dispatcherEndpoints(); + + abstract ImmutableList windmillServiceStubs(); + + abstract ImmutableList windmillMetadataServiceStubs(); } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index fbed81c1153d..858aeb159856 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -33,6 +33,8 @@ import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -53,6 +55,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.RemoteWindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers; import 
org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; @@ -84,14 +87,14 @@ }) @SuppressWarnings("nullness") // TODO(https://github.com/apache/beam/issues/20497 public final class GrpcWindmillServer extends WindmillServerStub { + public static final Duration LOCALHOST_MAX_BACKOFF = Duration.millis(500); + public static final Duration MAX_BACKOFF = Duration.standardSeconds(30); + private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class); private static final int DEFAULT_LOG_EVERY_N_FAILURES = 20; - private static final Duration MIN_BACKOFF = Duration.millis(1); - private static final Duration MAX_BACKOFF = Duration.standardSeconds(30); private static final int NO_HEALTH_CHECK = -1; private static final String GRPC_LOCALHOST = "grpc:localhost"; - private final GrpcWindmillStreamFactory windmillStreamFactory; private final GrpcDispatcherClient dispatcherClient; private final StreamingDataflowWorkerOptions options; private final StreamingEngineThrottleTimers throttleTimers; @@ -100,44 +103,26 @@ public final class GrpcWindmillServer extends WindmillServerStub { // If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the // newer ComputationHeartbeatRequests. 
private final boolean sendKeyedGetDataRequests; - private Consumer> processHeartbeatResponses; + private final Consumer> processHeartbeatResponses; + private final GrpcWindmillStreamFactory windmillStreamFactory; private GrpcWindmillServer( - StreamingDataflowWorkerOptions options, GrpcDispatcherClient grpcDispatcherClient) { + StreamingDataflowWorkerOptions options, + GrpcWindmillStreamFactory grpcWindmillStreamFactory, + GrpcDispatcherClient grpcDispatcherClient, + Consumer> processHeartbeatResponses) { this.options = options; this.throttleTimers = StreamingEngineThrottleTimers.create(); this.maxBackoff = MAX_BACKOFF; - this.windmillStreamFactory = - GrpcWindmillStreamFactory.of( - JobHeader.newBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build()) - .setWindmillMessagesBetweenIsReadyChecks( - options.getWindmillMessagesBetweenIsReadyChecks()) - .setMaxBackOffSupplier(() -> maxBackoff) - .setLogEveryNStreamFailures( - options.getWindmillServiceStreamingLogEveryNStreamFailures()) - .setStreamingRpcBatchLimit(options.getWindmillServiceStreamingRpcBatchLimit()) - .build(); - windmillStreamFactory.scheduleHealthChecks( - options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()); - this.dispatcherClient = grpcDispatcherClient; this.syncApplianceStub = null; this.sendKeyedGetDataRequests = !options.isEnableStreamingEngine() || !DataflowRunner.hasExperiment( options, "streaming_engine_send_new_heartbeat_requests"); - this.processHeartbeatResponses = (responses) -> {}; - } - - @Override - public void setProcessHeartbeatResponses( - Consumer> processHeartbeatResponses) { this.processHeartbeatResponses = processHeartbeatResponses; - }; + this.windmillStreamFactory = grpcWindmillStreamFactory; + } private static StreamingDataflowWorkerOptions testOptions( boolean enableStreamingEngine, List additionalExperiments) { @@ -162,17 +147,22 @@ private static StreamingDataflowWorkerOptions testOptions( } 
/** Create new instance of {@link GrpcWindmillServer}. */ - public static GrpcWindmillServer create(StreamingDataflowWorkerOptions workerOptions) + public static GrpcWindmillServer create( + StreamingDataflowWorkerOptions workerOptions, + GrpcWindmillStreamFactory grpcWindmillStreamFactory, + Consumer> processHeartbeatResponses) throws IOException { GrpcWindmillServer grpcWindmillServer = new GrpcWindmillServer( workerOptions, + grpcWindmillStreamFactory, GrpcDispatcherClient.create( - WindmillStubFactory.remoteStubFactory( + new RemoteWindmillStubFactory( workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec(), workerOptions.getGcpCredential(), - workerOptions.getUseWindmillIsolatedChannels()))); + workerOptions.getUseWindmillIsolatedChannels())), + processHeartbeatResponses); if (workerOptions.getWindmillServiceEndpoint() != null) { grpcWindmillServer.configureWindmillServiceEndpoints(); } else if (!workerOptions.isEnableStreamingEngine() @@ -184,32 +174,62 @@ public static GrpcWindmillServer create(StreamingDataflowWorkerOptions workerOpt } @VisibleForTesting - static GrpcWindmillServer newTestInstance(String name, List experiments) { + static GrpcWindmillServer newTestInstance( + String name, + List experiments, + long clientId, + WindmillStubFactory windmillStubFactory) { ManagedChannel inProcessChannel = inProcessChannel(name); CloudWindmillServiceV1Alpha1Stub stub = CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel); - List dispatcherStubs = Lists.newArrayList(stub); + CloudWindmillMetadataServiceV1Alpha1Stub metadataStub = + CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel); + List windmillServiceStubs = Lists.newArrayList(stub); + List windmillMetadataServiceStubs = + Lists.newArrayList(metadataStub); + Set dispatcherEndpoints = Sets.newHashSet(HostAndPort.fromHost(name)); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( - WindmillStubFactory.inProcessStubFactory(name, unused -> inProcessChannel), - 
dispatcherStubs, + windmillStubFactory, + windmillServiceStubs, + windmillMetadataServiceStubs, dispatcherEndpoints); - return new GrpcWindmillServer( - testOptions(/* enableStreamingEngine= */ true, experiments), dispatcherClient); + + StreamingDataflowWorkerOptions testOptions = + testOptions(/* enableStreamingEngine= */ true, experiments); + GrpcWindmillStreamFactory windmillStreamFactory = + GrpcWindmillStreamFactory.of(createJobHeader(testOptions, clientId)).build(); + windmillStreamFactory.scheduleHealthChecks( + testOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs()); + return new GrpcWindmillServer(testOptions, windmillStreamFactory, dispatcherClient, noop -> {}); } @VisibleForTesting - static GrpcWindmillServer newApplianceTestInstance(Channel channel) { + static GrpcWindmillServer newApplianceTestInstance( + Channel channel, WindmillStubFactory windmillStubFactory) { + StreamingDataflowWorkerOptions options = + testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()); GrpcWindmillServer testServer = new GrpcWindmillServer( - testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()), + options, + GrpcWindmillStreamFactory.of(createJobHeader(options, 1)).build(), // No-op, Appliance does not use Dispatcher to call Streaming Engine. 
- GrpcDispatcherClient.create(WindmillStubFactory.inProcessStubFactory("test"))); + GrpcDispatcherClient.create(windmillStubFactory), + noop -> {}); testServer.syncApplianceStub = createWindmillApplianceStubWithDeadlineInterceptor(channel); return testServer; } + private static JobHeader createJobHeader(StreamingDataflowWorkerOptions options, long clientId) { + return Windmill.JobHeader.newBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .setClientId(clientId) + .build(); + } + private static WindmillApplianceGrpc.WindmillApplianceBlockingStub createWindmillApplianceStubWithDeadlineInterceptor(Channel channel) { return WindmillApplianceGrpc.newBlockingStub(channel) @@ -249,7 +269,7 @@ public void setWindmillServiceEndpoints(Set endpoints) { @Override public boolean isReady() { - return dispatcherClient.isReady(); + return dispatcherClient.hasInitializedEndpoints(); } private synchronized void initializeLocalHost(int port) { @@ -329,7 +349,7 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { @Override public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { return windmillStreamFactory.createGetWorkStream( - dispatcherClient.getDispatcherStub(), + dispatcherClient.getWindmillServiceStub(), GetWorkRequest.newBuilder(request) .setJobId(options.getJobId()) .setProjectId(options.getProject()) @@ -342,7 +362,7 @@ public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver rece @Override public GetDataStream getDataStream() { return windmillStreamFactory.createGetDataStream( - dispatcherClient.getDispatcherStub(), + dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer(), sendKeyedGetDataRequests, this.processHeartbeatResponses); @@ -351,7 +371,7 @@ public GetDataStream getDataStream() { @Override public CommitWorkStream commitWorkStream() { return windmillStreamFactory.createCommitWorkStream( - 
dispatcherClient.getDispatcherStub(), throttleTimers.commitWorkThrottleTimer()); + dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 7dc43e791e31..8696c464a0ff 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -32,6 +32,7 @@ import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; @@ -49,6 +50,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.AbstractStub; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.joda.time.Duration; import org.joda.time.Instant; @@ -109,8 +111,7 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { .setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT); } - private static 
CloudWindmillServiceV1Alpha1Stub withDeadline( - CloudWindmillServiceV1Alpha1Stub stub) { + private static > T withDefaultDeadline(T stub) { // Deadlines are absolute points in time, so generate a new one everytime this function is // called. return stub.withDeadlineAfter( @@ -123,7 +124,7 @@ public GetWorkStream createGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver processWorkItem) { return GrpcGetWorkStream.create( - responseObserver -> withDeadline(stub).getWorkStream(responseObserver), + responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -141,7 +142,7 @@ public GetWorkStream createDirectGetWorkStream( Supplier commitWorkStream, WorkItemProcessor workItemProcessor) { return GrpcDirectGetWorkStream.create( - responseObserver -> withDeadline(stub).getWorkStream(responseObserver), + responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -159,7 +160,7 @@ public GetDataStream createGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { return GrpcGetDataStream.create( - responseObserver -> withDeadline(stub).getDataStream(responseObserver), + responseObserver -> withDefaultDeadline(stub).getDataStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, @@ -180,7 +181,7 @@ public GetDataStream createGetDataStream( public CommitWorkStream createCommitWorkStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) { return GrpcCommitWorkStream.create( - responseObserver -> withDeadline(stub).commitWorkStream(responseObserver), + responseObserver -> withDefaultDeadline(stub).commitWorkStream(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, @@ -192,11 +193,11 @@ public CommitWorkStream createCommitWorkStream( } public GetWorkerMetadataStream 
createGetWorkerMetadataStream( - CloudWindmillServiceV1Alpha1Stub stub, + CloudWindmillMetadataServiceV1Alpha1Stub stub, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer onNewWindmillEndpoints) { return GrpcGetWorkerMetadataStream.create( - responseObserver -> withDeadline(stub).getWorkerMetadataStream(responseObserver), + responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java index 01783f6aa4d3..80c957996ab7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -131,7 +131,7 @@ private StreamingEngineClient( Suppliers.memoize( () -> streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getDispatcherStub(), + dispatcherClient.getWindmillMetadataServiceStub(), getWorkerMetadataThrottleTimer, endpoints -> // Run this on a separate thread than the grpc stream thread. 
@@ -267,7 +267,7 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi getWorkBudgetRefresher.requestBudgetRefresh(); } - public final ImmutableList getAndResetThrottleTimes() { + public ImmutableList getAndResetThrottleTimes() { StreamingEngineConnectionState currentConnections = connections.get(); ImmutableList keyedWorkStreamThrottleTimes = @@ -375,21 +375,10 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( } private CloudWindmillServiceV1Alpha1Stub createWindmillStub(Endpoint endpoint) { - switch (stubFactory.getKind()) { - // This is only used in tests. - case IN_PROCESS: - return stubFactory.inProcess().get(); - // Create stub for direct_endpoint or just default to Dispatcher stub. - case REMOTE: - return endpoint - .directEndpoint() - .map(stubFactory.remote()) - .orElseGet(dispatcherClient::getDispatcherStub); - // Should never be called, this switch statement is exhaustive. - default: - throw new UnsupportedOperationException( - "Only remote or in-process stub factories are available."); - } + return endpoint + .directEndpoint() + .map(stubFactory::createWindmillServiceStub) + .orElseGet(dispatcherClient::getWindmillServiceStub); } private static class StreamingEngineClientException extends IllegalStateException { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java new file mode 100644 index 000000000000..9978b74c7aa4 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/RemoteWindmillStubFactory.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; + +import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel; + +import com.google.auth.Credentials; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.auth.MoreCallCredentials; + +/** Creates remote stubs to talk to Streaming Engine. 
*/ +@Internal +@ThreadSafe +public final class RemoteWindmillStubFactory implements WindmillStubFactory { + private final int rpcChannelTimeoutSec; + private final Credentials gcpCredentials; + private final boolean useIsolatedChannels; + + public RemoteWindmillStubFactory( + int rpcChannelTimeoutSec, Credentials gcpCredentials, boolean useIsolatedChannels) { + this.rpcChannelTimeoutSec = rpcChannelTimeoutSec; + this.gcpCredentials = gcpCredentials; + this.useIsolatedChannels = useIsolatedChannels; + } + + @Override + public CloudWindmillServiceV1Alpha1Stub createWindmillServiceStub( + WindmillServiceAddress serviceAddress) { + CloudWindmillServiceV1Alpha1Stub windmillServiceStub = + CloudWindmillServiceV1Alpha1Grpc.newStub(createChannel(serviceAddress)); + return serviceAddress.getKind() != WindmillServiceAddress.Kind.AUTHENTICATED_GCP_SERVICE_ADDRESS + ? windmillServiceStub.withCallCredentials( + MoreCallCredentials.from(new VendoredCredentialsAdapter(gcpCredentials))) + : windmillServiceStub; + } + + @Override + public CloudWindmillMetadataServiceV1Alpha1Stub createWindmillMetadataServiceStub( + WindmillServiceAddress serviceAddress) { + return CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(createChannel(serviceAddress)) + .withCallCredentials( + MoreCallCredentials.from(new VendoredCredentialsAdapter(gcpCredentials))); + } + + private ManagedChannel createChannel(WindmillServiceAddress serviceAddress) { + Supplier channelFactory = + () -> remoteChannel(serviceAddress, rpcChannelTimeoutSec); + // IsolationChannel will create and manage separate RPC channels to the same serviceAddress via + // calling the channelFactory, else just directly return the RPC channel. + return useIsolatedChannels ? 
IsolationChannel.create(channelFactory) : channelFactory.get(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 68c82c5907b7..cf31436d3647 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -22,8 +22,11 @@ import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ForwardingChannelBuilder2; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.alts.AltsChannelBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.GrpcSslContexts; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.netty.NegotiationType; @@ -57,12 +60,28 @@ static ManagedChannel remoteChannel( return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); // switch is exhaustive will never happen. 
+ case AUTHENTICATED_GCP_SERVICE_ADDRESS: + return remoteDirectChannel( + windmillServiceAddress.authenticatedGcpServiceAddress(), + windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6 and GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); } } + static ManagedChannel remoteDirectChannel( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress, + int windmillServiceRpcChannelTimeoutSec) { + return withDefaultChannelOptions( + AltsChannelBuilder.forAddress( + authenticatedGcpServiceAddress.gcpServiceAddress().getHost(), + authenticatedGcpServiceAddress.gcpServiceAddress().getPort()) + .overrideAuthority(authenticatedGcpServiceAddress.authenticatingService()), + windmillServiceRpcChannelTimeoutSec) + .build(); + } + public static ManagedChannel remoteChannel( HostAndPort endpoint, int windmillServiceRpcChannelTimeoutSec) { try { @@ -100,6 +119,17 @@ public static ManagedChannel remoteChannel( private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) throws SSLException { + return withDefaultChannelOptions(channelBuilder, windmillServiceRpcChannelTimeoutSec) + .flowControlWindow(10 * 1024 * 1024) + .negotiationType(NegotiationType.TLS) + // Set ciphers(null) to not use GCM, which is disabled for Dataflow + // due to it being horribly slow. 
+ .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) + .build(); + } + + private static > T withDefaultChannelOptions( + T channelBuilder, int windmillServiceRpcChannelTimeoutSec) { if (windmillServiceRpcChannelTimeoutSec > 0) { channelBuilder .keepAliveTime(windmillServiceRpcChannelTimeoutSec, TimeUnit.SECONDS) @@ -108,14 +138,8 @@ private static ManagedChannel createRemoteChannel( } return channelBuilder - .flowControlWindow(10 * 1024 * 1024) .maxInboundMessageSize(Integer.MAX_VALUE) - .maxInboundMetadataSize(1024 * 1024) - .negotiationType(NegotiationType.TLS) - // Set ciphers(null) to not use GCM, which is disabled for Dataflow - // due to it being horribly slow. - .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) - .build(); + .maxInboundMetadataSize(1024 * 1024); } public static class WindmillChannelCreationException extends IllegalStateException { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java index 7ad46a21c08c..e5e523445a64 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillStubFactory.java @@ -17,62 +17,16 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; -import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel; - -import com.google.auth.Credentials; -import com.google.auto.value.AutoOneOf; -import java.util.function.Function; -import java.util.function.Supplier; -import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.auth.VendoredCredentialsAdapter; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.auth.MoreCallCredentials; - -/** - * Used to create stubs to talk to Streaming Engine. Stubs are either in-process for testing, or - * remote. - */ -@AutoOneOf(WindmillStubFactory.Kind.class) -public abstract class WindmillStubFactory { - - public static WindmillStubFactory inProcessStubFactory( - String testName, Function channelFactory) { - return AutoOneOf_WindmillStubFactory.inProcess( - () -> CloudWindmillServiceV1Alpha1Grpc.newStub(channelFactory.apply(testName))); - } - - public static WindmillStubFactory inProcessStubFactory(String testName) { - return AutoOneOf_WindmillStubFactory.inProcess( - () -> - CloudWindmillServiceV1Alpha1Grpc.newStub( - WindmillChannelFactory.inProcessChannel(testName))); - } - - public static WindmillStubFactory remoteStubFactory( - int rpcChannelTimeoutSec, Credentials gcpCredentials, boolean useIsolatedChannels) { - return AutoOneOf_WindmillStubFactory.remote( - directEndpoint -> { - Supplier channelSupplier = - () -> remoteChannel(directEndpoint, rpcChannelTimeoutSec); - return CloudWindmillServiceV1Alpha1Grpc.newStub( - useIsolatedChannels - ? 
IsolationChannel.create(channelSupplier) - : channelSupplier.get()) - .withCallCredentials( - MoreCallCredentials.from(new VendoredCredentialsAdapter(gcpCredentials))); - }); - } - - public abstract Kind getKind(); - - public abstract Supplier inProcess(); +import org.apache.beam.sdk.annotations.Internal; - public abstract Function remote(); +/** Used to create stubs to talk to Streaming Engine. */ +@Internal +public interface WindmillStubFactory { + CloudWindmillServiceV1Alpha1Stub createWindmillServiceStub(WindmillServiceAddress serviceAddress); - public enum Kind { - IN_PROCESS, - REMOTE - } + CloudWindmillMetadataServiceV1Alpha1Stub createWindmillMetadataServiceStub( + WindmillServiceAddress serviceAddress); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 2cfec6d3139a..069fcac07c80 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -32,6 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -42,6 +43,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest; @@ -70,7 +73,7 @@ import org.slf4j.LoggerFactory; /** An in-memory Windmill server that offers provided work and data. */ -class FakeWindmillServer extends WindmillServerStub { +final class FakeWindmillServer extends WindmillServerStub { private static final Logger LOG = LoggerFactory.getLogger(FakeWindmillServer.class); private final ResponseQueue workToOffer; private final ResponseQueue dataToOffer; @@ -86,9 +89,11 @@ class FakeWindmillServer extends WindmillServerStub { private final List getDataRequests = new ArrayList<>(); private boolean isReady = true; private boolean dropStreamingCommits = false; - private Consumer> processHeartbeatResponses; + private final Consumer> processHeartbeatResponses; - public FakeWindmillServer(ErrorCollector errorCollector) { + public FakeWindmillServer( + ErrorCollector errorCollector, + Function> computationStateFetcher) { workToOffer = new ResponseQueue() .returnByDefault(Windmill.GetWorkResponse.getDefaultInstance()); @@ -106,13 +111,7 @@ public FakeWindmillServer(ErrorCollector errorCollector) { this.errorCollector = errorCollector; statsReceived = new ArrayList<>(); droppedStreamingCommits = new ConcurrentHashMap<>(); - processHeartbeatResponses = (responses) -> {}; - } - - @Override - public void setProcessHeartbeatResponses( - Consumer> processHeartbeatResponses) { - this.processHeartbeatResponses = processHeartbeatResponses; + this.processHeartbeatResponses = new WorkHeartbeatResponseProcessor(computationStateFetcher); } public void setDropStreamingCommits(boolean dropStreamingCommits) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index abd2cbbac6ef..df806fcb9786 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -70,8 +70,10 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; @@ -179,6 +181,7 @@ import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Assert; +import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -261,6 +264,7 @@ public Long get() { return idGenerator.getAndIncrement(); } }; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Rule public BlockingFn blockingFn = new BlockingFn(); @Rule public TestRule restoreMDC = new RestoreDataflowLoggingMDC(); @@ -268,6 +272,10 @@ public Long get() { WorkUnitClient mockWorkUnitClient = mock(WorkUnitClient.class); HotKeyLogger hotKeyLogger = mock(HotKeyLogger.class); + private final ConcurrentMap computationMap = new ConcurrentHashMap<>(); + private final FakeWindmillServer server = + new FakeWindmillServer(errorCollector, id -> Optional.ofNullable(computationMap.get(id))); + public StreamingDataflowWorkerTest(Boolean streamingEngine) { this.streamingEngine = streamingEngine; } @@ -286,6 +294,12 @@ private static CounterUpdate getCounter(Iterable counters, String return null; } + @Before + public void setUp() { + computationMap.clear(); + server.clearCommitsReceived(); + } + static Work createMockWork(long workToken) { return createMockWork(workToken, work -> {}); } @@ -760,8 +774,7 @@ private ByteString addPaneTag(PaneInfo pane, byte[] windowBytes) throws 
IOExcept return output.toByteString(); } - private StreamingDataflowWorkerOptions createTestingPipelineOptions( - FakeWindmillServer server, String... args) { + private StreamingDataflowWorkerOptions createTestingPipelineOptions(String... args) { List argsList = Lists.newArrayList(args); if (streamingEngine) { argsList.add("--experiments=enable_streaming_engine"); @@ -771,8 +784,9 @@ private StreamingDataflowWorkerOptions createTestingPipelineOptions( .as(StreamingDataflowWorkerOptions.class); options.setAppName("StreamingWorkerHarnessTest"); options.setJobId("test_job_id"); + options.setProject("test_project"); + options.setWorkerId("test_worker"); options.setStreaming(true); - options.setWindmillServerStub(server); options.setActiveWorkRefreshPeriodMillis(0); return options; } @@ -782,10 +796,12 @@ private StreamingDataflowWorker makeWorker( StreamingDataflowWorkerOptions options, boolean publishCounters, Supplier clock, - Function executorSupplier) - throws Exception { + Function executorSupplier) { StreamingDataflowWorker worker = new StreamingDataflowWorker( + server, + new Random().nextLong(), + computationMap, Collections.singletonList(defaultMapTask(instructions)), IntrinsicMapTaskExecutorFactory.defaultFactory(), mockWorkUnitClient, @@ -802,8 +818,7 @@ private StreamingDataflowWorker makeWorker( private StreamingDataflowWorker makeWorker( List instructions, StreamingDataflowWorkerOptions options, - boolean publishCounters) - throws Exception { + boolean publishCounters) { return makeWorker( instructions, options, @@ -819,9 +834,8 @@ public void testBasicHarness() throws Exception { makeSourceInstruction(StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); - StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + 
makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); final int numIters = 2000; @@ -848,9 +862,7 @@ private void runTestBasic(int numCommitThreads) throws Exception { makeSourceInstruction(StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server.setIsReady(false); - StreamingConfigTask streamingConfig = new StreamingConfigTask(); streamingConfig.setStreamingComputationConfigs( ImmutableList.of(makeDefaultStreamingComputationConfig(instructions))); @@ -859,9 +871,8 @@ private void runTestBasic(int numCommitThreads) throws Exception { workItem.setStreamingConfigTask(streamingConfig); when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem)); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); - options.setWindmillServiceCommitThreads(numCommitThreads); - StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); final int numIters = 2000; @@ -900,7 +911,6 @@ public void testHotKeyLogging() throws Exception { makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())), makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server.setIsReady(false); StreamingConfigTask streamingConfig = new StreamingConfigTask(); @@ -911,9 +921,11 @@ public void testHotKeyLogging() throws Exception { workItem.setStreamingConfigTask(streamingConfig); when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem)); - StreamingDataflowWorkerOptions options = - createTestingPipelineOptions(server, "--hotKeyLoggingEnabled=true"); - StreamingDataflowWorker worker = 
makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + makeWorker( + instructions, + createTestingPipelineOptions("--hotKeyLoggingEnabled=true"), + true /* publishCounters */); worker.start(); final int numIters = 2000; @@ -938,7 +950,6 @@ public void testHotKeyLoggingNotEnabled() throws Exception { makeSourceInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())), makeSinkInstruction(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server.setIsReady(false); StreamingConfigTask streamingConfig = new StreamingConfigTask(); @@ -949,8 +960,8 @@ public void testHotKeyLoggingNotEnabled() throws Exception { workItem.setStreamingConfigTask(streamingConfig); when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem)); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); - StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); final int numIters = 2000; @@ -975,10 +986,8 @@ public void testIgnoreRetriedKeys() throws Exception { makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); - - StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); for (int i = 0; i < numIters; ++i) { @@ -1099,8 +1108,7 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception { makeDoFnInstruction(blockingFn, 0, 
StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setNumberOfWorkerHarnessThreads(expectedNumberOfThreads); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); @@ -1143,13 +1151,12 @@ public void testKeyTokenInvalidException() throws Exception { makeDoFnInstruction(new KeyTokenInvalidFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server .whenGetWorkCalled() .thenReturn(makeInput(0, 0, DEFAULT_KEY_STRING, DEFAULT_SHARDING_KEY)); StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); server.waitForEmptyWorkQueue(); @@ -1177,11 +1184,10 @@ public void testKeyCommitTooLargeException() throws Exception { makeDoFnInstruction(new LargeCommitFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server.setExpectedExceptionCount(1); StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.setMaxWorkItemCommitBytes(1000); worker.start(); @@ -1245,7 +1251,6 @@ public void testKeyChange() throws Exception { makeDoFnInstruction(new ChangeKeysFn(), 0, kvCoder), makeSinkInstruction(kvCoder, 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); for (int i = 0; i < 2; i++) { server .whenGetWorkCalled() @@ -1261,7 +1266,7 @@ public void testKeyChange() throws Exception { } 
StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); Map result = server.waitForAndGetCommits(4); @@ -1302,7 +1307,6 @@ public void testExceptions() throws Exception { makeDoFnInstruction(new TestExceptionFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); server.setExpectedExceptionCount(1); String keyString = keyStringForIndex(0); server @@ -1337,7 +1341,7 @@ public void testExceptions() throws Exception { Collections.singletonList(DEFAULT_WINDOW)))); StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); server.waitForEmptyWorkQueue(); @@ -1433,8 +1437,6 @@ public void testAssignWindows() throws Exception { addWindowsInstruction, makeSinkInstruction(StringUtf8Coder.of(), 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - int timestamp1 = 0; int timestamp2 = 1000000; @@ -1444,7 +1446,7 @@ public void testAssignWindows() throws Exception { .thenReturn(makeInput(timestamp2, timestamp2)); StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), false /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), false /* publishCounters */); worker.start(); Map result = server.waitForAndGetCommits(2); @@ -1561,10 +1563,8 @@ public void testMergeWindows() throws Exception { mergeWindowsInstruction, makeSinkInstruction(groupedCoder, 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), false /* publishCounters */); + 
makeWorker(instructions, createTestingPipelineOptions(), false /* publishCounters */); Map nameMap = new HashMap<>(); nameMap.put("MergeWindowsStep", "MergeWindows"); worker.addStateNameMappings(nameMap); @@ -1850,10 +1850,8 @@ public void testMergeWindowsCaching() throws Exception { makeDoFnInstruction(new PassthroughDoFn(), 1, groupedCoder), makeSinkInstruction(groupedCoder, 2)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), false /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), false /* publishCounters */); Map nameMap = new HashMap<>(); nameMap.put("MergeWindowsStep", "MergeWindows"); worker.addStateNameMappings(nameMap); @@ -2148,10 +2146,8 @@ private void runMergeSessionsActions(List actions) throws Exception { mergeWindowsInstruction, makeSinkInstruction(groupedCoder, 1)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), false /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), false /* publishCounters */); Map nameMap = new HashMap<>(); nameMap.put("MergeWindowsStep", "MergeWindows"); worker.addStateNameMappings(nameMap); @@ -2172,8 +2168,7 @@ private void runMergeSessionsActions(List actions) throws Exception { } @Test - public void testMergeSessionWindows() throws Exception { - // Test a single late window. 
+ public void testMergeSessionWindows_singleLateWindow() throws Exception { runMergeSessionsActions( Collections.singletonList( new Action( @@ -2183,7 +2178,10 @@ public void testMergeSessionWindows() throws Exception { buildHold("/gAAAAAAAAAsK/+uhold", -1, true), buildHold("/gAAAAAAAAAsK/+uextra", -1, true)) .withTimers(buildWatermarkTimer("/s/gAAAAAAAAAsK/+0", 3600010)))); + } + @Test + public void testMergeSessionWindows() throws Exception { // Test the behavior with an: // - on time window that is triggered due to watermark advancement // - a late window that is triggered immediately due to count @@ -2298,11 +2296,10 @@ public void testUnboundedSources() throws Exception { List finalizeTracker = Lists.newArrayList(); TestCountingSource.setFinalizeTracker(finalizeTracker); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); StreamingDataflowWorker worker = makeWorker( makeUnboundedSourcePipeline(), - createTestingPipelineOptions(server), + createTestingPipelineOptions(), false /* publishCounters */); worker.start(); @@ -2465,11 +2462,10 @@ public void testUnboundedSourcesDrain() throws Exception { List finalizeTracker = Lists.newArrayList(); TestCountingSource.setFinalizeTracker(finalizeTracker); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); StreamingDataflowWorker worker = makeWorker( makeUnboundedSourcePipeline(), - createTestingPipelineOptions(server), + createTestingPipelineOptions(), true /* publishCounters */); worker.start(); @@ -2579,8 +2575,7 @@ public void testUnboundedSourceWorkRetry() throws Exception { List finalizeTracker = Lists.newArrayList(); TestCountingSource.setFinalizeTracker(finalizeTracker); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setWorkerCacheMb(0); // Disable state cache so it doesn't detect retry. 
StreamingDataflowWorker worker = makeWorker(makeUnboundedSourcePipeline(), options, false /* publishCounters */); @@ -3094,10 +3089,10 @@ public void testExceptionInvalidatesCache() throws Exception { // 25. Read state as 42 // 26. Take counter reader checkpoint 2 // 27. CommitWork[2] (message 0:2, checkpoint 2) - FakeWindmillServer server = new FakeWindmillServer(errorCollector); + server.setExpectedExceptionCount(2); - DataflowPipelineOptions options = createTestingPipelineOptions(server); + DataflowPipelineOptions options = createTestingPipelineOptions(); options.setNumWorkers(1); DataflowPipelineDebugOptions debugOptions = options.as(DataflowPipelineDebugOptions.class); debugOptions.setUnboundedReaderMaxElements(1); @@ -3281,9 +3276,8 @@ public void testHugeCommits() throws Exception { makeDoFnInstruction(new FanoutFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); - StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); + StreamingDataflowWorker worker = + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); server.whenGetWorkCalled().thenReturn(makeInput(0, TimeUnit.MILLISECONDS.toMicros(0))); @@ -3300,8 +3294,7 @@ public void testActiveWorkRefresh() throws Exception { makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3324,8 +3317,7 @@ public void 
testActiveWorkFailure() throws Exception { makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3427,8 +3419,7 @@ public void testLatencyAttributionToQueuedState() throws Exception { new FakeSlowDoFn(clock, Duration.millis(1000)), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); // A single-threaded worker processes work sequentially, leaving a second work item in state // QUEUED until the first work item is committed. 
@@ -3470,8 +3461,7 @@ public void testLatencyAttributionToActiveState() throws Exception { new FakeSlowDoFn(clock, Duration.millis(1000)), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker( @@ -3504,8 +3494,7 @@ public void testLatencyAttributionToReadingState() throws Exception { makeDoFnInstruction(new ReadingDoFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker( @@ -3545,7 +3534,7 @@ public void testLatencyAttributionToCommittingState() throws Exception { makeSinkInstruction(StringUtf8Coder.of(), 0)); // Inject latency on the fake clock when the server receives a CommitWork call. 
- FakeWindmillServer server = new FakeWindmillServer(errorCollector); + server .whenCommitWorkCalled() .answerByDefault( @@ -3553,7 +3542,7 @@ public void testLatencyAttributionToCommittingState() throws Exception { clock.sleep(Duration.millis(1000)); return Windmill.CommitWorkResponse.getDefaultInstance(); }); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker( @@ -3588,8 +3577,7 @@ public void testLatencyAttributionPopulatedInCommitRequest() throws Exception { new FakeSlowDoFn(clock, Duration.millis(dofnWaitTimeMs)), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); options.setNumberOfWorkerHarnessThreads(1); StreamingDataflowWorker worker = @@ -3641,8 +3629,7 @@ public void testDoFnLatencyBreakdownsReportedOnCommit() throws Exception { makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(100); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3678,8 +3665,7 @@ public void testDoFnActiveMessageMetadataReportedOnHeartbeat() throws Exception makeDoFnInstruction(new SlowDoFn(), 0, StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new 
FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setActiveWorkRefreshPeriodMillis(10); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3715,12 +3701,11 @@ public void testLimitOnOutputBundleSize() throws Exception { final int numMessagesInCustomSourceShard = 100000; // 100K input messages. final int inflatedSizePerMessage = 10000; // x10k => 1GB total output size. - FakeWindmillServer server = new FakeWindmillServer(errorCollector); StreamingDataflowWorker worker = makeWorker( makeUnboundedSourcePipeline( numMessagesInCustomSourceShard, new InflateDoFn(inflatedSizePerMessage)), - createTestingPipelineOptions(server), + createTestingPipelineOptions(), false /* publishCounters */); worker.start(); @@ -3802,9 +3787,8 @@ public void testLimitOnOutputBundleSizeWithMultipleSinks() throws Exception { 1, GlobalWindow.Coder.INSTANCE)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); StreamingDataflowWorker worker = - makeWorker(instructions, createTestingPipelineOptions(server), true /* publishCounters */); + makeWorker(instructions, createTestingPipelineOptions(), true /* publishCounters */); worker.start(); // Test new key. 
@@ -3870,8 +3854,7 @@ public void testStuckCommit() throws Exception { makeSourceInstruction(StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setStuckCommitDurationMillis(2000); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3904,14 +3887,12 @@ public void testStuckCommit() throws Exception { removeDynamicFields(result.get(1L))); } - private void runNumCommitThreadsTest(int configNumCommitThreads, int expectedNumCommitThreads) - throws Exception { + private void runNumCommitThreadsTest(int configNumCommitThreads, int expectedNumCommitThreads) { List instructions = Arrays.asList( makeSourceInstruction(StringUtf8Coder.of()), makeSinkInstruction(StringUtf8Coder.of(), 0)); - FakeWindmillServer server = new FakeWindmillServer(errorCollector); - StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server); + StreamingDataflowWorkerOptions options = createTestingPipelineOptions(); options.setWindmillServiceCommitThreads(configNumCommitThreads); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); @@ -3920,7 +3901,7 @@ private void runNumCommitThreadsTest(int configNumCommitThreads, int expectedNum } @Test - public void testDefaultNumCommitThreads() throws Exception { + public void testDefaultNumCommitThreads() { if (streamingEngine) { runNumCommitThreadsTest(1, 1); runNumCommitThreadsTest(2, 2); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 5d17795b28fd..515beba0c88d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -33,7 +33,7 @@ import java.util.function.Consumer; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; @@ -62,14 +62,14 @@ @RunWith(JUnit4.class) public class GrpcGetWorkerMetadataStreamTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final String IPV6_ADDRESS_1 = "2001:db8:0000:bac5:0000:0000:fed0:81a2"; private static final String IPV6_ADDRESS_2 = "2001:db8:0000:bac5:0000:0000:fed0:82a3"; + private static final String AUTHENTICATING_SERVICE = "test.googleapis.com"; private static final List DIRECT_PATH_ENDPOINTS = Lists.newArrayList( WorkerMetadataResponse.Endpoint.newBuilder() .setDirectEndpoint(IPV6_ADDRESS_1) - .setWorkerToken("worker_token") + .setBackendWorkerToken("worker_token") .build()); private static final Map GLOBAL_DATA_ENDPOINTS = Maps.newHashMap(); @@ -83,6 +83,7 @@ public class GrpcGetWorkerMetadataStreamTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry(); private final Set> streamRegistry = new HashSet<>(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; @@ -93,8 +94,8 @@ private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( serviceRegistry.addService(getWorkerMetadataTestStub); return GrpcGetWorkerMetadataStream.create( responseObserver -> - CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getWorkerMetadataStream(responseObserver), + CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel) + .getWorkerMetadata(responseObserver), FluentBackoff.DEFAULT.backoff(), StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), streamRegistry, @@ -123,7 +124,7 @@ public void setUp() throws IOException { "global_data", WorkerMetadataResponse.Endpoint.newBuilder() .setDirectEndpoint(IPV6_ADDRESS_1) - .setWorkerToken("worker_token") + .setBackendWorkerToken("worker_token") .build()); } @@ -139,6 +140,7 @@ public void testGetWorkerMetadata() { .setMetadataVersion(1) .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS) .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) + .setExternalEndpoint(AUTHENTICATING_SERVICE) .build(); TestWindmillEndpointsConsumer testWindmillEndpointsConsumer = new TestWindmillEndpointsConsumer(); @@ -153,7 +155,9 @@ public void testGetWorkerMetadata() { assertThat(testWindmillEndpointsConsumer.windmillEndpoints) .containsExactlyElementsIn( DIRECT_PATH_ENDPOINTS.stream() - .map(WindmillEndpoints.Endpoint::from) + .map( + endpointProto -> + WindmillEndpoints.Endpoint.from(endpointProto, AUTHENTICATING_SERVICE)) .collect(Collectors.toList())); } @@ -164,6 +168,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { .setMetadataVersion(1) .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS) .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) + .setExternalEndpoint(AUTHENTICATING_SERVICE) .build(); TestWindmillEndpointsConsumer 
testWindmillEndpointsConsumer = Mockito.spy(new TestWindmillEndpointsConsumer()); @@ -187,6 +192,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { .setMetadataVersion(initialResponse.getMetadataVersion() + 1) .addAllWorkEndpoints(newDirectPathEndpoints) .putAllGlobalDataEndpoints(newGlobalDataEndpoints) + .setExternalEndpoint(AUTHENTICATING_SERVICE) .build(); testStub.injectWorkerMetadata(newWorkMetadataResponse); @@ -196,7 +202,9 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { assertThat(testWindmillEndpointsConsumer.windmillEndpoints) .containsExactlyElementsIn( newDirectPathEndpoints.stream() - .map(WindmillEndpoints.Endpoint::from) + .map( + endpointProto -> + WindmillEndpoints.Endpoint.from(endpointProto, AUTHENTICATING_SERVICE)) .collect(Collectors.toList())); } @@ -207,6 +215,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { .setMetadataVersion(2) .addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS) .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) + .setExternalEndpoint(AUTHENTICATING_SERVICE) .build(); TestWindmillEndpointsConsumer testWindmillEndpointsConsumer = @@ -268,7 +277,8 @@ public void testSendHealthCheck() { } private static class GetWorkerMetadataTestStub - extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + extends CloudWindmillMetadataServiceV1Alpha1Grpc + .CloudWindmillMetadataServiceV1Alpha1ImplBase { private final TestGetWorkMetadataRequestObserver requestObserver; private @Nullable StreamObserver responseObserver; @@ -277,7 +287,7 @@ private GetWorkerMetadataTestStub(TestGetWorkMetadataRequestObserver requestObse } @Override - public StreamObserver getWorkerMetadataStream( + public StreamObserver getWorkerMetadata( StreamObserver responseObserver) { if (this.responseObserver == null) { this.responseObserver = responseObserver; diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 15610462e014..37dc7eff917a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -73,6 +73,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.CallOptions; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel; @@ -87,6 +89,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; @@ -109,10 +112,13 @@ }) public class GrpcWindmillServerTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + @Rule public GrpcCleanupRule 
grpcCleanup = new GrpcCleanupRule(); + @Rule public ErrorCollector errorCollector = new ErrorCollector(); + private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); private static final int STREAM_CHUNK_SIZE = 2 << 20; + private final long clientId = 10L; private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - @Rule public ErrorCollector errorCollector = new ErrorCollector(); private Server server; private GrpcWindmillServer client; private int remainingErrors = 20; @@ -128,7 +134,13 @@ public void setUp() throws Exception { .build() .start(); - this.client = GrpcWindmillServer.newTestInstance(name, new ArrayList<>()); + this.client = + GrpcWindmillServer.newTestInstance( + name, + new ArrayList<>(), + clientId, + new FakeWindmillStubFactory( + () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)))); } @After @@ -197,7 +209,9 @@ public ClientCall interceptCall( .build(), testInterceptor); - this.client = GrpcWindmillServer.newApplianceTestInstance(inprocessChannel); + this.client = + GrpcWindmillServer.newApplianceTestInstance( + inprocessChannel, new FakeWindmillStubFactory(() -> inprocessChannel)); Windmill.GetWorkResponse response1 = client.getWork(GetWorkRequest.getDefaultInstance()); Windmill.GetWorkResponse response2 = client.getWork(GetWorkRequest.getDefaultInstance()); @@ -346,6 +360,7 @@ public void onNext(StreamingGetDataRequest chunk) { .setJobId("job") .setProjectId("project") .setWorkerId("worker") + .setClientId(clientId) .build())); sawHeader = true; } else { @@ -555,6 +570,7 @@ public void onNext(StreamingCommitWorkRequest request) { .setJobId("job") .setProjectId("project") .setWorkerId("worker") + .setClientId(clientId) .build())); sawHeader = true; LOG.info("Received header"); @@ -839,6 +855,7 @@ public void onNext(StreamingGetDataRequest chunk) { .setJobId("job") .setProjectId("project") .setWorkerId("worker") + .setClientId(clientId) .build())); sawHeader = true; 
} else { @@ -921,7 +938,10 @@ public void testStreamingGetDataHeartbeatsAsHeartbeatRequests() throws Exception this.client = GrpcWindmillServer.newTestInstance( "TestServer", - Collections.singletonList("streaming_engine_send_new_heartbeat_requests")); + Collections.singletonList("streaming_engine_send_new_heartbeat_requests"), + clientId, + new FakeWindmillStubFactory( + () -> WindmillChannelFactory.inProcessChannel("TestServer"))); // This server records the heartbeats observed but doesn't respond. final List receivedHeartbeats = new ArrayList<>(); @@ -945,6 +965,7 @@ public void onNext(StreamingGetDataRequest chunk) { .setJobId("job") .setProjectId("project") .setWorkerId("worker") + .setClientId(clientId) .build())); sawHeader = true; } else { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java index 4831726c49e0..f755f0333387 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -37,7 +37,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; @@ 
-46,6 +46,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; @@ -73,7 +74,6 @@ @RunWith(JUnit4.class) public class StreamingEngineClientTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS = WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443)); private static final ImmutableMap DEFAULT = @@ -95,24 +95,25 @@ public class StreamingEngineClientTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final Set channels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final GrpcWindmillStreamFactory streamFactory = spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); private final WindmillStubFactory stubFactory = - WindmillStubFactory.inProcessStubFactory( - "StreamingEngineClientTest", - name -> { + new FakeWindmillStubFactory( + () -> { ManagedChannel channel = - grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)); + grpcCleanup.register( + WindmillChannelFactory.inProcessChannel("StreamingEngineClientTest")); channels.add(channel); return channel; }); private final GrpcDispatcherClient dispatcherClient = - GrpcDispatcherClient.forTesting(stubFactory, new ArrayList<>(), new HashSet<>()); + GrpcDispatcherClient.forTesting( + stubFactory, new ArrayList<>(), 
new ArrayList<>(), new HashSet<>()); private final GetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); private final AtomicReference connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private Server fakeStreamingEngineServer; private CountDownLatch getWorkerMetadataReady; private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub; @@ -140,7 +141,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { } private static WorkerMetadataResponse.Endpoint metadataResponseEndpoint(String workerToken) { - return WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build(); + return WorkerMetadataResponse.Endpoint.newBuilder().setBackendWorkerToken(workerToken).build(); } @Before @@ -269,16 +270,22 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken) + .build()) .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken2) + .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); WorkerMetadataResponse secondWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(2) .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken3) + .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); @@ -315,21 +322,27 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) .addWorkEndpoints( - 
WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken) + .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); WorkerMetadataResponse secondWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(2) .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken2).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken2) + .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); WorkerMetadataResponse thirdWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(3) .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder().setWorkerToken(workerToken3).build()) + WorkerMetadataResponse.Endpoint.newBuilder() + .setBackendWorkerToken(workerToken3) + .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); @@ -362,7 +375,8 @@ private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed( } private static class GetWorkerMetadataTestStub - extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + extends CloudWindmillMetadataServiceV1Alpha1Grpc + .CloudWindmillMetadataServiceV1Alpha1ImplBase { private static final WorkerMetadataResponse CLOSE_ALL_STREAMS = WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build(); private final CountDownLatch ready; @@ -373,7 +387,7 @@ private GetWorkerMetadataTestStub(CountDownLatch ready) { } @Override - public StreamObserver getWorkerMetadataStream( + public StreamObserver getWorkerMetadata( StreamObserver responseObserver) { if (this.responseObserver == null) { ready.countDown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java 
new file mode 100644 index 000000000000..3dd40e5d5c73 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.testing; + +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Channel; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +@VisibleForTesting +public final class FakeWindmillStubFactory implements WindmillStubFactory { + private final Supplier channelFactory; + + public FakeWindmillStubFactory(Supplier channelFactory) { + this.channelFactory = channelFactory; + } + + @Override + public CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub + createWindmillServiceStub(WindmillServiceAddress serviceAddress) { + return CloudWindmillServiceV1Alpha1Grpc.newStub(channelFactory.get()); + } + + @Override + public CloudWindmillMetadataServiceV1Alpha1Grpc.CloudWindmillMetadataServiceV1Alpha1Stub + createWindmillMetadataServiceStub(WindmillServiceAddress serviceAddress) { + return CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(channelFactory.get()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 0c824ca301b3..4677ff9dcc9a 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -847,6 +847,11 @@ message JobHeader { optional string project_id = 2; // Worker id is meant for logging only. Do not rely on it for other decisions. 
optional string worker_id = 3; + optional fixed64 client_id = 4; + optional string region_id = 5; + // Used by the user worker to communicate to a specific windmill worker. This + // is initially passed to the user worker via GetWorkerMetadata. + optional string backend_worker_token = 6; } message StreamingCommitRequestChunk { @@ -902,14 +907,19 @@ message WorkerMetadataResponse { message Endpoint { // IPv6 address of a streaming engine windmill worker. optional string direct_endpoint = 1; - optional string worker_token = 2; + optional string backend_worker_token = 2; + optional int64 port = 3; } + repeated Endpoint work_endpoints = 2; // Maps from GlobalData tag to the endpoint that should be used for GetData // calls to retrieve that global data. map global_data_endpoints = 3; + // Used to set gRPC authority. + optional string external_endpoint = 5; + reserved 4; } diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto index d9183e54e0dd..101bae170db5 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill_service.proto @@ -33,10 +33,6 @@ service CloudWindmillServiceV1Alpha1 { rpc GetWorkStream(stream .windmill.StreamingGetWorkRequest) returns (stream .windmill.StreamingGetWorkResponseChunk); - // Gets worker metadata. Response is a stream. - rpc GetWorkerMetadataStream(stream .windmill.WorkerMetadataRequest) - returns (stream .windmill.WorkerMetadataResponse); - // Gets data from Windmill. 
rpc GetData(.windmill.GetDataRequest) returns(.windmill.GetDataResponse); @@ -52,3 +48,9 @@ service CloudWindmillServiceV1Alpha1 { rpc CommitWorkStream(stream .windmill.StreamingCommitWorkRequest) returns (stream .windmill.StreamingCommitResponse); } + +service CloudWindmillMetadataServiceV1Alpha1 { + // Gets worker metadata. Response is a stream. + rpc GetWorkerMetadata(stream.windmill.WorkerMetadataRequest) + returns (stream.windmill.WorkerMetadataResponse); +}