From 6c8dd6ebfc24741746b0692045198aba5d6af6f7 Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Sat, 22 Apr 2023 20:38:10 +0200 Subject: [PATCH] Refactor internal http client usage The previous approach was dangerous, as it was modifying configuration for all HTTP clients and adding global filters even for HTTP clients there were not supposed to be used for internal communication. With this change, it's now intentional which HTTP clients are supposed to use secured internal communication. --- .../FailureDetectorModule.java | 14 +- .../io/trino/server/CoordinatorModule.java | 14 +- ...InternalCommunicationHttpClientModule.java | 142 ++++++++++++++++++ .../server/InternalCommunicationModule.java | 30 +--- .../io/trino/server/ServerMainModule.java | 10 +- .../jwt/JwtAuthenticatorSupportModule.java | 14 +- .../security/oauth2/OAuth2ServiceModule.java | 13 +- .../TestHeartbeatFailureDetector.java | 2 + .../trino/server/TestGenerateTokenFilter.java | 6 +- 9 files changed, 172 insertions(+), 73 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java diff --git a/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java b/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java index 259280f11093..37f0d031b61c 100644 --- a/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java +++ b/core/trino-main/src/main/java/io/trino/failuredetector/FailureDetectorModule.java @@ -14,22 +14,22 @@ package io.trino.failuredetector; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; import org.weakref.jmx.guice.ExportBinder; import static io.airlift.configuration.ConfigBinder.configBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; public class FailureDetectorModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + protected void setup(Binder binder) { - httpClientBinder(binder) - .bindHttpClient("failure-detector", ForFailureDetector.class) - .withTracing(); + install(internalHttpClientModule("failure-detector", ForFailureDetector.class) + .withTracing() + .build()); configBinder(binder).bindConfig(FailureDetectorConfig.class); diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 7a4076566372..e50401d42a08 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -129,12 +129,12 @@ import static io.airlift.configuration.ConditionalModule.conditionalModule; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.BIN_PACKING; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.FIXED_COUNT; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; @@ -167,10 +167,10 @@ protected void setup(Binder binder) }); // failure detector - binder.install(new FailureDetectorModule()); + install(new FailureDetectorModule()); jaxrsBinder(binder).bind(NodeResource.class); jaxrsBinder(binder).bind(WorkerResource.class); - httpClientBinder(binder).bindHttpClient("workerInfo", ForWorkerInfo.class); + install(internalHttpClientModule("workerInfo", ForWorkerInfo.class).build()); // query monitor jsonCodecBinder(binder).bindJsonCodec(ExecutionFailureInfo.class); @@ -206,12 +206,12 @@ protected void setup(Binder binder) // cluster memory manager binder.bind(ClusterMemoryManager.class).in(Scopes.SINGLETON); binder.bind(ClusterMemoryPoolManager.class).to(ClusterMemoryManager.class).in(Scopes.SINGLETON); - httpClientBinder(binder).bindHttpClient("memoryManager", ForMemoryManager.class) + install(internalHttpClientModule("memoryManager", ForMemoryManager.class) .withTracing() .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); - }); + }).build()); bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.NONE, NoneLowMemoryKiller.class); bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesTaskLowMemoryKiller.class); @@ -293,14 +293,14 @@ protected void setup(Binder binder) binder.bind(RemoteTaskStats.class).in(Scopes.SINGLETON); newExporter(binder).export(RemoteTaskStats.class).withGeneratedName(); - httpClientBinder(binder).bindHttpClient("scheduler", ForScheduler.class) + install(internalHttpClientModule("scheduler", ForScheduler.class) .withTracing() .withFilter(GenerateTraceTokenRequestFilter.class) .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); config.setMaxConnectionsPerServer(250); - }); + }).build()); binder.bind(ScheduledExecutorService.class).annotatedWith(ForScheduler.class) .toInstance(newSingleThreadScheduledExecutor(threadsNamed("stage-scheduler"))); diff --git a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java new file mode 100644 index 000000000000..632231c648e6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationHttpClientModule.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Binder; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.airlift.http.client.HttpClientBinder.HttpClientBindingBuilder; +import io.airlift.http.client.HttpClientConfig; +import io.airlift.http.client.HttpRequestFilter; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static java.util.Objects.requireNonNull; + +public class InternalCommunicationHttpClientModule + extends AbstractConfigurationAwareModule +{ + private final String clientName; + private final Class annotation; + private final boolean withTracing; + private final Consumer configDefaults; + private final List> filters; + + private InternalCommunicationHttpClientModule( + String clientName, + Class annotation, + boolean withTracing, + Consumer configDefaults, + List> filters) + { + this.clientName = requireNonNull(clientName, "clientName is null"); + this.annotation = requireNonNull(annotation, "annotation is null"); + this.withTracing = withTracing; + this.configDefaults = requireNonNull(configDefaults, "configDefaults is null"); + this.filters = ImmutableList.copyOf(requireNonNull(filters, "filters is null")); + } + + @Override + protected void setup(Binder binder) + { + HttpClientBindingBuilder httpClientBindingBuilder = httpClientBinder(binder).bindHttpClient(clientName, annotation); + InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + httpClientBindingBuilder.withConfigDefaults(httpConfig -> { + configureClient(httpConfig, internalCommunicationConfig); + configDefaults.accept(httpConfig); + }); + + httpClientBindingBuilder.addFilterBinding().to(InternalAuthenticationManager.class); + + if (withTracing) { + httpClientBindingBuilder.withTracing(); + } + + filters.forEach(httpClientBindingBuilder::withFilter); + } + + static void configureClient(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + httpConfig.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); + if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) { + configureClientForAutomaticHttps(httpConfig, internalCommunicationConfig); + } + else { + configureClientForManualHttps(httpConfig, internalCommunicationConfig); + } + } + + private static void configureClientForAutomaticHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + String sharedSecret = internalCommunicationConfig.getSharedSecret() + .orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled")); + httpConfig.setAutomaticHttpsSharedSecret(sharedSecret); + } + + private static void configureClientForManualHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig) + { + httpConfig.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); + httpConfig.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); + httpConfig.setTrustStorePath(internalCommunicationConfig.getTrustStorePath()); + httpConfig.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword()); + httpConfig.setAutomaticHttpsSharedSecret(null); + } + + public static class Builder + { + private final String clientName; + private final Class annotation; + private boolean withTracing; + private Consumer configDefaults = config -> {}; + private final List> filters = new ArrayList<>(); + + private Builder(String clientName, Class annotation) + { + this.clientName = requireNonNull(clientName, "clientName is null"); + this.annotation = requireNonNull(annotation, "annotation is null"); + } + + public Builder withTracing() + { + this.withTracing = true; + return this; + } + + public Builder withConfigDefaults(Consumer configDefaults) + { + this.configDefaults = requireNonNull(configDefaults, "configDefaults is null"); + return this; + } + + public Builder withFilter(Class requestFilter) + { + this.filters.add(requestFilter); + return this; + } + + public InternalCommunicationHttpClientModule build() + { + return new InternalCommunicationHttpClientModule(clientName, annotation, withTracing, configDefaults, filters); + } + } + + public static InternalCommunicationHttpClientModule.Builder internalHttpClientModule(String clientName, Class annotation) + { + return new Builder(clientName, annotation); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java index a09276a55f10..e2194872ce10 100644 --- a/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java +++ b/core/trino-main/src/main/java/io/trino/server/InternalCommunicationModule.java @@ -14,9 +14,9 @@ package io.trino.server; import com.google.inject.Binder; +import com.google.inject.multibindings.Multibinder; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.discovery.client.ForDiscoveryClient; -import io.airlift.http.client.HttpClientConfig; import io.airlift.http.client.HttpRequestFilter; import io.airlift.http.client.Request; import io.airlift.http.server.HttpsConfig; @@ -30,7 +30,6 @@ import static com.google.inject.multibindings.Multibinder.newSetBinder; import static io.airlift.configuration.ConfigBinder.configBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.node.AddressToHostname.encodeAddressAsHostname; import static io.airlift.node.NodeConfig.AddressSource.IP_ENCODED_AS_HOSTNAME; @@ -40,38 +39,17 @@ public class InternalCommunicationModule @Override protected void setup(Binder binder) { - // Set defaults for all HttpClients in the same guice context - // so in case of any additions or alternations here an update in: - // io.trino.server.security.jwt.JwtAuthenticatorSupportModule.JwkModule.configure - // and - // io.trino.server.security.oauth2.OAuth2ServiceModule.setup - // may also be required. InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + Multibinder discoveryFilterBinder = newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class); if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) { String sharedSecret = internalCommunicationConfig.getSharedSecret() .orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled")); configBinder(binder).bindConfigDefaults(HttpsConfig.class, config -> config.setAutomaticHttpsSharedSecret(sharedSecret)); - configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { - config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); - config.setAutomaticHttpsSharedSecret(sharedSecret); - }); configBinder(binder).bindConfigGlobalDefaults(NodeConfig.class, config -> config.setInternalAddressSource(IP_ENCODED_AS_HOSTNAME)); - // rewrite discovery client requests to use IP encoded as hostname - newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class).addBinding().to(DiscoveryEncodeAddressAsHostname.class); + discoveryFilterBinder.addBinding().to(DiscoveryEncodeAddressAsHostname.class); } - else { - configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { - config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled()); - config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); - config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); - config.setTrustStorePath(internalCommunicationConfig.getTrustStorePath()); - config.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword()); - config.setAutomaticHttpsSharedSecret(null); - }); - } - + discoveryFilterBinder.addBinding().to(InternalAuthenticationManager.class); binder.bind(InternalAuthenticationManager.class); - httpClientBinder(binder).bindGlobalFilter(InternalAuthenticationManager.class); } private static class DiscoveryEncodeAddressAsHostname diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index e9a7b81a0136..3708cf5803a4 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -170,7 +170,6 @@ import static io.airlift.configuration.ConditionalModule.conditionalModule; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder; -import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; @@ -178,6 +177,7 @@ import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.TOPOLOGY; import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.UNIFORM; import static io.trino.operator.RetryPolicy.TASK; +import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -255,12 +255,12 @@ protected void setup(Binder binder) binder.bind(DiscoveryNodeManager.class).in(Scopes.SINGLETON); binder.bind(InternalNodeManager.class).to(DiscoveryNodeManager.class).in(Scopes.SINGLETON); newExporter(binder).export(DiscoveryNodeManager.class).withGeneratedName(); - httpClientBinder(binder).bindHttpClient("node-manager", ForNodeManager.class) + install(internalHttpClientModule("node-manager", ForNodeManager.class) .withTracing() .withConfigDefaults(config -> { config.setIdleTimeout(new Duration(30, SECONDS)); config.setRequestTimeout(new Duration(10, SECONDS)); - }); + }).build()); // node scheduler // TODO: remove from NodePartitioningManager and move to CoordinatorModule @@ -339,7 +339,7 @@ protected void setup(Binder binder) // exchange client binder.bind(DirectExchangeClientSupplier.class).to(DirectExchangeClientFactory.class).in(Scopes.SINGLETON); - httpClientBinder(binder).bindHttpClient("exchange", ForExchange.class) + install(internalHttpClientModule("exchange", ForExchange.class) .withTracing() .withFilter(GenerateTraceTokenRequestFilter.class) .withConfigDefaults(config -> { @@ -347,7 +347,7 @@ protected void setup(Binder binder) config.setRequestTimeout(new Duration(10, SECONDS)); config.setMaxConnectionsPerServer(250); config.setMaxContentLength(DataSize.of(32, MEGABYTE)); - }); + }).build()); configBinder(binder).bindConfig(DirectExchangeClientConfig.class); binder.bind(ExchangeExecutionMBean.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java index bf4801d74aba..43cb56fdc8eb 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java @@ -54,19 +54,7 @@ private static class JwkModule @Override public void configure(Binder binder) { - httpClientBinder(binder) - .bindHttpClient("jwk", ForJwt.class) - // Reset HttpClient default configuration to override InternalCommunicationModule changes. - // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration - // for all clients in the same guice context. This, however, does not make sense for this client which will - // very rarely use the same SSL setup as internal communication, so using the system default truststore - // makes more sense. - .withConfigDefaults(config -> config - .setKeyStorePath(null) - .setKeyStorePassword(null) - .setTrustStorePath(null) - .setTrustStorePassword(null) - .setAutomaticHttpsSharedSecret(null)); + httpClientBinder(binder).bindHttpClient("jwk", ForJwt.class); } @Provides diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java index c6dbbdd0181d..6b16166df104 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java @@ -50,18 +50,7 @@ protected void setup(Binder binder) .in(Scopes.SINGLETON); install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableDiscovery, this::bindOidcDiscovery, this::bindStaticConfiguration)); install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableRefreshTokens, this::enableRefreshTokens, this::disableRefreshTokens)); - httpClientBinder(binder) - .bindHttpClient("oauth2-jwk", ForOAuth2.class) - // Reset to defaults to override InternalCommunicationModule changes to this client default configuration. - // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration - // for all clients in this guice context. This does not make sense for this client which will very rarely - // use the same SSL configuration, so using the system default truststore makes more sense. - .withConfigDefaults(config -> config - .setKeyStorePath(null) - .setKeyStorePassword(null) - .setTrustStorePath(null) - .setTrustStorePassword(null) - .setAutomaticHttpsSharedSecret(null)); + httpClientBinder(binder).bindHttpClient("oauth2-jwk", ForOAuth2.class); } private void enableRefreshTokens(Binder binder) diff --git a/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java b/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java index c83777d1c5ef..9c69c91c5505 100644 --- a/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java +++ b/core/trino-main/src/test/java/io/trino/failuredetector/TestHeartbeatFailureDetector.java @@ -30,6 +30,7 @@ import io.trino.execution.QueryManagerConfig; import io.trino.failuredetector.HeartbeatFailureDetector.Stats; import io.trino.server.InternalCommunicationConfig; +import io.trino.server.security.SecurityConfig; import org.testng.annotations.Test; import javax.ws.rs.GET; @@ -61,6 +62,7 @@ public void testExcludesCurrentNode() new JaxrsModule(), new FailureDetectorModule(), binder -> { + configBinder(binder).bindConfig(SecurityConfig.class); configBinder(binder).bindConfig(InternalCommunicationConfig.class); configBinder(binder).bindConfig(QueryManagerConfig.class); discoveryBinder(binder).bindSelector("trino"); diff --git a/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java b/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java index 6ecf788fb8c1..a27c6aa7b8bd 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java +++ b/core/trino-main/src/test/java/io/trino/server/TestGenerateTokenFilter.java @@ -66,9 +66,9 @@ public void setup() // extract the filter List filters = httpClient.getRequestFilters(); - assertEquals(filters.size(), 3); - assertInstanceOf(filters.get(2), GenerateTraceTokenRequestFilter.class); - filter = (GenerateTraceTokenRequestFilter) filters.get(2); + assertEquals(filters.size(), 2); + assertInstanceOf(filters.get(1), GenerateTraceTokenRequestFilter.class); + filter = (GenerateTraceTokenRequestFilter) filters.get(1); } @AfterClass(alwaysRun = true)