diff --git a/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java b/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java index 259280f11093..37f0d031b61c 100644 --- a/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java +++ b/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java @@ -14,22 +14,22 @@ package io.trino.failuredetector; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; import org.weakref.jmx.guice.ExportBinder; import static io.airlift.configuration.ConfigBinder.configBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; public class FailureDetectorModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + protected void setup(Binder binder) { - httpClientBinder(binder) - .bindHttpClient("failure-detector", ForFailureDetector.class) - .withTracing(); + install(internalHttpClientModule("failure-detector", ForFailureDetector.class) + .withTracing() + .build()); configBinder(binder).bindConfig(FailureDetectorConfig.class); diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 7a4076566372..e50401d42a08 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -129,12 +129,12 @@ import static io.airlift.configuration.ConditionalModule.conditionalModule; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.BIN_PACKING; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.FIXED_COUNT; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; @@ -167,10 +167,10 @@ protected void setup(Binder binder) }); // failure detector - binder.install(new FailureDetectorModule()); + install(new FailureDetectorModule()); jaxrsBinder(binder).bind(NodeResource.class); jaxrsBinder(binder).bind(WorkerResource.class); - httpClientBinder(binder).bindHttpClient("workerInfo", ForWorkerInfo.class); + install(internalHttpClientModule("workerInfo", ForWorkerInfo.class).build()); // query monitor jsonCodecBinder(binder).bindJsonCodec(ExecutionFailureInfo.class); @@ -206,12 +206,12 @@ protected void setup(Binder binder) // cluster memory manager binder.bind(ClusterMemoryManager.class).in(Scopes.SINGLETON); binder.bind(ClusterMemoryPoolManager.class).to(ClusterMemoryManager.class).in(Scopes.SINGLETON); - httpClientBinder(binder).bindHttpClient("memoryManager", ForMemoryManager.class) + install(internalHttpClientModule("memoryManager", ForMemoryManager.class) .withTracing() .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); - }); + }).build()); bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.NONE, NoneLowMemoryKiller.class); bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesTaskLowMemoryKiller.class); @@ -293,14 +293,14 @@ protected void setup(Binder binder) binder.bind(RemoteTaskStats.class).in(Scopes.SINGLETON); newExporter(binder).export(RemoteTaskStats.class).withGeneratedName(); - httpClientBinder(binder).bindHttpClient("scheduler", ForScheduler.class) + install(internalHttpClientModule("scheduler", ForScheduler.class) .withTracing() .withFilter(GenerateTraceTokenRequestFilter.class) .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); config.setMaxConnectionsPerServer(250); - }); + }).build()); binder.bind(ScheduledExecutorService.class).annotatedWith(ForScheduler.class) .toInstance(newSingleThreadScheduledExecutor(threadsNamed("stage-scheduler"))); diff --git a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java new file mode 100644 index 000000000000..632231c648e6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Binder; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.airlift.http.client.HttpClientBinder.HttpClientBindingBuilder; +import io.airlift.http.client.HttpClientConfig; +import io.airlift.http.client.HttpRequestFilter; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static java.util.Objects.requireNonNull; + +public class InternalCommunicationHttpClientModule + extends AbstractConfigurationAwareModule +{ + private final String clientName; + private final Class annotation; + private final boolean withTracing; + private final Consumer configDefaults; + private final List> filters; + + private InternalCommunicationHttpClientModule( + String clientName, + Class annotation, + boolean withTracing, + Consumer configDefaults, + List> filters) + { + this.clientName = requireNonNull(clientName, "clientName is null"); + this.annotation = requireNonNull(annotation, "annotation is null"); + this.withTracing = withTracing; + this.configDefaults = requireNonNull(configDefaults, "configDefaults is null"); + this.filters = ImmutableList.copyOf(requireNonNull(filters, "filters is null")); + } + + @Override + protected void setup(Binder binder) + { + HttpClientBindingBuilder httpClientBindingBuilder = httpClientBinder(binder).bindHttpClient(clientName, annotation); + InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + httpClientBindingBuilder.withConfigDefaults(httpConfig -> { + configureClient(httpConfig, internalCommunicationConfig); + configDefaults.accept(httpConfig); + }); + + httpClientBindingBuilder.addFilterBinding().to(InternalAuthenticationManager.class); + + if (withTracing) { + httpClientBindingBuilder.withTracing(); + } + + filters.forEach(httpClientBindingBuilder::withFilter); + } + + static void configureClient(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + httpConfig.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); + if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) { + configureClientForAutomaticHttps(httpConfig, internalCommunicationConfig); + } + else { + configureClientForManualHttps(httpConfig, internalCommunicationConfig); + } + } + + private static void configureClientForAutomaticHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + String sharedSecret = internalCommunicationConfig.getSharedSecret() + .orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled")); + httpConfig.setAutomaticHttpsSharedSecret(sharedSecret); + } + + private static void configureClientForManualHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + httpConfig.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); + httpConfig.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); + httpConfig.setTrustStorePath(internalCommunicationConfig.getTrustStorePath()); + httpConfig.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword()); + httpConfig.setAutomaticHttpsSharedSecret(null); + } + + public static class Builder + { + private final String clientName; + private final Class annotation; + private boolean withTracing; + private Consumer configDefaults = config -> {}; + private final List> filters = new ArrayList<>(); + + private Builder(String clientName, Class annotation) + { + this.clientName = requireNonNull(clientName, "clientName is null"); + this.annotation = requireNonNull(annotation, "annotation is null"); + } + + public Builder withTracing() + { + this.withTracing = true; + return this; + } + + public Builder withConfigDefaults(Consumer configDefaults) + { + this.configDefaults = requireNonNull(configDefaults, "configDefaults is null"); + return this; + } + + public Builder withFilter(Class requestFilter) + { + this.filters.add(requestFilter); + return this; + } + + public InternalCommunicationHttpClientModule build() + { + return new InternalCommunicationHttpClientModule(clientName, annotation, withTracing, configDefaults, filters); + } + } + + public static InternalCommunicationHttpClientModule.Builder internalHttpClientModule(String clientName, Class annotation) + { + return new Builder(clientName, annotation); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java index d71bd5dfee5b..e2194872ce10 100644 --- a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java +++ b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java @@ -14,9 +14,9 @@ package io.trino.server; import com.google.inject.Binder; +import com.google.inject.multibindings.Multibinder; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.discovery.client.ForDiscoveryClient; -import io.airlift.http.client.HttpClientConfig; import io.airlift.http.client.HttpRequestFilter; import io.airlift.http.client.Request; import io.airlift.http.server.HttpsConfig; @@ -30,7 +30,6 @@ import static com.google.inject.multibindings.Multibinder.newSetBinder; import static io.airlift.configuration.ConfigBinder.configBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.node.AddressToHostname.encodeAddressAsHostname; import static io.airlift.node.NodeConfig.AddressSource.IP_ENCODED_AS_HOSTNAME; @@ -40,59 +39,47 @@ public class InternalCommunicationModule @Override protected void setup(Binder binder) { - // Set defaults for all HttpClients in the same guice context - // so in case of any additions or alternations here an update in: - // io.trino.server.security.jwt.JwtAuthenticatorSupportModule.JwkModule.configure - // and - // io.trino.server.security.oauth2.OAuth2ServiceModule.setup - // may also be required. InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + Multibinder discoveryFilterBinder = newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class); if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) { String sharedSecret = internalCommunicationConfig.getSharedSecret() .orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled")); configBinder(binder).bindConfigDefaults(HttpsConfig.class, config -> config.setAutomaticHttpsSharedSecret(sharedSecret)); - configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { - config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); - config.setAutomaticHttpsSharedSecret(sharedSecret); - }); configBinder(binder).bindConfigGlobalDefaults(NodeConfig.class, config -> config.setInternalAddressSource(IP_ENCODED_AS_HOSTNAME)); - // rewrite discovery client requests to use IP encoded as hostname - newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class).addBinding() - .toInstance(request -> Request.Builder.fromRequest(request) - .setUri(toIpEncodedAsHostnameUri(request.getUri())) - .build()); + discoveryFilterBinder.addBinding().to(DiscoveryEncodeAddressAsHostname.class); } - else { - configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { - config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); - config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); - config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); - config.setTrustStorePath(internalCommunicationConfig.getTrustStorePath()); - config.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword()); - config.setAutomaticHttpsSharedSecret(null); - }); - } - + discoveryFilterBinder.addBinding().to(InternalAuthenticationManager.class); binder.bind(InternalAuthenticationManager.class); - httpClientBinder(binder).bindGlobalFilter(InternalAuthenticationManager.class); } - private static URI toIpEncodedAsHostnameUri(URI uri) + private static class DiscoveryEncodeAddressAsHostname + implements HttpRequestFilter { - if (!uri.getScheme().equals("https")) { - return uri; - } - try { - String host = uri.getHost(); - InetAddress inetAddress = InetAddress.getByName(host); - String addressAsHostname = encodeAddressAsHostname(inetAddress); - return new URI(uri.getScheme(), uri.getUserInfo(), addressAsHostname, uri.getPort(), uri.getPath(), uri.getQuery(), uri.getFragment()); + @Override + public Request filterRequest(Request request) + { + return Request.Builder.fromRequest(request) + .setUri(toIpEncodedAsHostnameUri(request.getUri())) + .build(); } - catch (UnknownHostException e) { - throw new UncheckedIOException(e); - } - catch (URISyntaxException e) { - throw new RuntimeException(e); + + private static URI toIpEncodedAsHostnameUri(URI uri) + { + if (!uri.getScheme().equals("https")) { + return uri; + } + try { + String host = uri.getHost(); + InetAddress inetAddress = InetAddress.getByName(host); + String addressAsHostname = encodeAddressAsHostname(inetAddress); + return new URI(uri.getScheme(), uri.getUserInfo(), addressAsHostname, uri.getPort(), uri.getPath(), uri.getQuery(), uri.getFragment()); + } + catch (UnknownHostException e) { + throw new UncheckedIOException(e); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } } } } diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index e9a7b81a0136..3708cf5803a4 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -170,7 +170,6 @@ import static io.airlift.configuration.ConditionalModule.conditionalModule; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; @@ -178,6 +177,7 @@ import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.TOPOLOGY; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.UNIFORM; import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -255,12 +255,12 @@ protected void setup(Binder binder) binder.bind(DiscoveryNodeManager.class).in(Scopes.SINGLETON); binder.bind(InternalNodeManager.class).to(DiscoveryNodeManager.class).in(Scopes.SINGLETON); newExporter(binder).export(DiscoveryNodeManager.class).withGeneratedName(); - httpClientBinder(binder).bindHttpClient("node-manager", ForNodeManager.class) + install(internalHttpClientModule("node-manager", ForNodeManager.class) .withTracing() .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); - }); + }).build()); // node scheduler // TODO: remove from NodePartitioningManager and move to CoordinatorModule @@ -339,7 +339,7 @@ protected void setup(Binder binder) // exchange client binder.bind(DirectExchangeClientSupplier.class).to(DirectExchangeClientFactory.class).in(Scopes.SINGLETON); - httpClientBinder(binder).bindHttpClient("exchange", ForExchange.class) + install(internalHttpClientModule("exchange", ForExchange.class) .withTracing() .withFilter(GenerateTraceTokenRequestFilter.class) .withConfigDefaults(config -> { @@ -347,7 +347,7 @@ protected void setup(Binder binder) config.setRequestTimeout(new Duration(10, SECONDS)); config.setMaxConnectionsPerServer(250); config.setMaxContentLength(DataSize.of(32, MEGABYTE)); - }); + }).build()); configBinder(binder).bindConfig(DirectExchangeClientConfig.class); binder.bind(ExchangeExecutionMBean.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java index bf4801d74aba..43cb56fdc8eb 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java @@ -54,19 +54,7 @@ private static class JwkModule @Override public void configure(Binder binder) { - httpClientBinder(binder) - .bindHttpClient("jwk", ForJwt.class) - // Reset HttpClient default configuration to override InternalCommunicationModule changes. - // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration - // for all clients in the same guice context. This, however, does not make sense for this client which will - // very rarely use the same SSL setup as internal communication, so using the system default truststore - // makes more sense. - .withConfigDefaults(config -> config - .setKeyStorePath(null) - .setKeyStorePassword(null) - .setTrustStorePath(null) - .setTrustStorePassword(null) - .setAutomaticHttpsSharedSecret(null)); + httpClientBinder(binder).bindHttpClient("jwk", ForJwt.class); } @Provides diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java index c6dbbdd0181d..6b16166df104 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java @@ -50,18 +50,7 @@ protected void setup(Binder binder) .in(Scopes.SINGLETON); install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableDiscovery, this::bindOidcDiscovery, this::bindStaticConfiguration)); install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableRefreshTokens, this::enableRefreshTokens, this::disableRefreshTokens)); - httpClientBinder(binder) - .bindHttpClient("oauth2-jwk", ForOAuth2.class) - // Reset to defaults to override InternalCommunicationModule changes to this client default configuration. - // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration - // for all clients in this guice context. This does not make sense for this client which will very rarely - // use the same SSL configuration, so using the system default truststore makes more sense. - .withConfigDefaults(config -> config - .setKeyStorePath(null) - .setKeyStorePassword(null) - .setTrustStorePath(null) - .setTrustStorePassword(null) - .setAutomaticHttpsSharedSecret(null)); + httpClientBinder(binder).bindHttpClient("oauth2-jwk", ForOAuth2.class); } private void enableRefreshTokens(Binder binder) diff --git a/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java b/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java index c83777d1c5ef..9c69c91c5505 100644 --- a/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java +++ b/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java @@ -30,6 +30,7 @@ import io.trino.execution.QueryManagerConfig; import io.trino.failuredetector.HeartbeatFailureDetector.Stats; import io.trino.server.InternalCommunicationConfig; +import io.trino.server.security.SecurityConfig; import org.testng.annotations.Test; import javax.ws.rs.GET; @@ -61,6 +62,7 @@ public void testExcludesCurrentNode() new JaxrsModule(), new FailureDetectorModule(), binder -> { + configBinder(binder).bindConfig(SecurityConfig.class); configBinder(binder).bindConfig(InternalCommunicationConfig.class); configBinder(binder).bindConfig(QueryManagerConfig.class); discoveryBinder(binder).bindSelector("trino"); diff --git a/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java b/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java index 6ecf788fb8c1..a27c6aa7b8bd 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java +++ b/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java @@ -66,9 +66,9 @@ public void setup() // extract the filter List filters = httpClient.getRequestFilters(); - assertEquals(filters.size(), 3); - assertInstanceOf(filters.get(2), GenerateTraceTokenRequestFilter.class); - filter = (GenerateTraceTokenRequestFilter) filters.get(2); + assertEquals(filters.size(), 2); + assertInstanceOf(filters.get(1), GenerateTraceTokenRequestFilter.class); + filter = (GenerateTraceTokenRequestFilter) filters.get(1); } @AfterClass(alwaysRun = true)