diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9389bf397e88..c8503d95781e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -17,26 +17,27 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; -import java.net.Inet6Address; -import java.net.InetSocketAddress; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Channel; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ForwardingChannelBuilder2; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; -import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.alts.AltsChannelBuilder; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.alts.AltsChannelCredentials; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.netty.GrpcSslContexts; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.netty.NegotiationType; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.netty.NettyChannelBuilder; +import org.apache.beam.vendor.grpc.v1p69p0.io.netty.handler.ssl.SslContext; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Utility class used to create different RPC Channels. */ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; private static final int MAX_REMOTE_TRACE_EVENTS = 100; + // 10MiB. + private static final int WINDMILL_MAX_FLOW_CONTROL_WINDOW = + NettyChannelBuilder.DEFAULT_FLOW_CONTROL_WINDOW * 10; private WindmillChannelFactory() {} @@ -69,13 +70,14 @@ public static ManagedChannel remoteChannel( } } - static ManagedChannel remoteDirectChannel( + private static ManagedChannel remoteDirectChannel( AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress, int windmillServiceRpcChannelTimeoutSec) { return withDefaultChannelOptions( - AltsChannelBuilder.forAddress( + NettyChannelBuilder.forAddress( authenticatedGcpServiceAddress.gcpServiceAddress().getHost(), - authenticatedGcpServiceAddress.gcpServiceAddress().getPort()) + authenticatedGcpServiceAddress.gcpServiceAddress().getPort(), + new AltsChannelCredentials.Builder().build()) .overrideAuthority(authenticatedGcpServiceAddress.authenticatingService()), windmillServiceRpcChannelTimeoutSec) .build(); @@ -83,41 +85,27 @@ static ManagedChannel remoteDirectChannel( public static ManagedChannel remoteChannel( HostAndPort endpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort()), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(endpoint, sslException); - } + return withDefaultChannelOptions( + NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort()), + windmillServiceRpcChannelTimeoutSec) + .negotiationType(NegotiationType.TLS) + .sslContext(dataflowGrpcSslContext(endpoint)) + .build(); } - public static Channel remoteChannel( - Inet6Address directEndpoint, int port, int windmillServiceRpcChannelTimeoutSec) { + @SuppressWarnings("nullness") + private static SslContext dataflowGrpcSslContext(HostAndPort endpoint) { try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, port)), - windmillServiceRpcChannelTimeoutSec); + // Set ciphers(null) to not use GCM, which is disabled for Dataflow + // due to it being horribly slow. + return GrpcSslContexts.forClient().ciphers(null).build(); } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); + throw new WindmillChannelCreationException(endpoint, sslException); } } - @SuppressWarnings("nullness") - private static ManagedChannel createRemoteChannel( - NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) - throws SSLException { - return withDefaultChannelOptions(channelBuilder, windmillServiceRpcChannelTimeoutSec) - .flowControlWindow(10 * 1024 * 1024) - .negotiationType(NegotiationType.TLS) - // Set ciphers(null) to not use GCM, which is disabled for Dataflow - // due to it being horribly slow. - .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) - .build(); - } - - private static > T withDefaultChannelOptions( - T channelBuilder, int windmillServiceRpcChannelTimeoutSec) { + private static NettyChannelBuilder withDefaultChannelOptions( + NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) { if (windmillServiceRpcChannelTimeoutSec > 0) { channelBuilder .keepAliveTime(windmillServiceRpcChannelTimeoutSec, TimeUnit.SECONDS) @@ -128,10 +116,12 @@ private static > T withDefaultChannelOpti return channelBuilder .maxInboundMessageSize(Integer.MAX_VALUE) .maxTraceEvents(MAX_REMOTE_TRACE_EVENTS) - .maxInboundMetadataSize(1024 * 1024); + // 1MiB + .maxInboundMetadataSize(1024 * 1024) + .flowControlWindow(WINDMILL_MAX_FLOW_CONTROL_WINDOW); } - public static class WindmillChannelCreationException extends IllegalStateException { + private static class WindmillChannelCreationException extends IllegalStateException { private WindmillChannelCreationException(HostAndPort endpoint, SSLException sourceException) { super( String.format( @@ -139,12 +129,5 @@ private WindmillChannelCreationException(HostAndPort endpoint, SSLException sour endpoint.getHost(), endpoint.getPort()), sourceException); } - - WindmillChannelCreationException(String directEndpoint, Throwable sourceException) { - super( - String.format( - "Exception thrown when trying to create channel to endpoint={%s}", directEndpoint), - sourceException); - } } }