Skip to content

Commit

Permalink
Refactor internal http client usage
Browse files Browse the repository at this point in the history
The previous approach was dangerous, as it was modifying configuration for all HTTP clients and adding global filters even for HTTP clients there were not
supposed to be used for internal communication.

With this change, it's now intentional which HTTP clients are supposed to use secured internal communication.
  • Loading branch information
wendigo authored and kokosing committed Apr 26, 2023
1 parent d3f0f99 commit 6c8dd6e
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@
package io.trino.failuredetector;

import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import org.weakref.jmx.guice.ExportBinder;

import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule;

public class FailureDetectorModule
implements Module
extends AbstractConfigurationAwareModule
{
@Override
public void configure(Binder binder)
protected void setup(Binder binder)
{
httpClientBinder(binder)
.bindHttpClient("failure-detector", ForFailureDetector.class)
.withTracing();
install(internalHttpClientModule("failure-detector", ForFailureDetector.class)
.withTracing()
.build());

configBinder(binder).bindConfig(FailureDetectorConfig.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@
import static io.airlift.configuration.ConditionalModule.conditionalModule;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder;
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder;
import static io.airlift.json.JsonCodecBinder.jsonCodecBinder;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.BIN_PACKING;
import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.FIXED_COUNT;
import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
Expand Down Expand Up @@ -167,10 +167,10 @@ protected void setup(Binder binder)
});

// failure detector
binder.install(new FailureDetectorModule());
install(new FailureDetectorModule());
jaxrsBinder(binder).bind(NodeResource.class);
jaxrsBinder(binder).bind(WorkerResource.class);
httpClientBinder(binder).bindHttpClient("workerInfo", ForWorkerInfo.class);
install(internalHttpClientModule("workerInfo", ForWorkerInfo.class).build());

// query monitor
jsonCodecBinder(binder).bindJsonCodec(ExecutionFailureInfo.class);
Expand Down Expand Up @@ -206,12 +206,12 @@ protected void setup(Binder binder)
// cluster memory manager
binder.bind(ClusterMemoryManager.class).in(Scopes.SINGLETON);
binder.bind(ClusterMemoryPoolManager.class).to(ClusterMemoryManager.class).in(Scopes.SINGLETON);
httpClientBinder(binder).bindHttpClient("memoryManager", ForMemoryManager.class)
install(internalHttpClientModule("memoryManager", ForMemoryManager.class)
.withTracing()
.withConfigDefaults(config -> {
config.setIdleTimeout(new Duration(30, SECONDS));
config.setRequestTimeout(new Duration(10, SECONDS));
});
}).build());

bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.NONE, NoneLowMemoryKiller.class);
bindLowMemoryTaskKiller(LowMemoryTaskKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesTaskLowMemoryKiller.class);
Expand Down Expand Up @@ -293,14 +293,14 @@ protected void setup(Binder binder)
binder.bind(RemoteTaskStats.class).in(Scopes.SINGLETON);
newExporter(binder).export(RemoteTaskStats.class).withGeneratedName();

httpClientBinder(binder).bindHttpClient("scheduler", ForScheduler.class)
install(internalHttpClientModule("scheduler", ForScheduler.class)
.withTracing()
.withFilter(GenerateTraceTokenRequestFilter.class)
.withConfigDefaults(config -> {
config.setIdleTimeout(new Duration(30, SECONDS));
config.setRequestTimeout(new Duration(10, SECONDS));
config.setMaxConnectionsPerServer(250);
});
}).build());

binder.bind(ScheduledExecutorService.class).annotatedWith(ForScheduler.class)
.toInstance(newSingleThreadScheduledExecutor(threadsNamed("stage-scheduler")));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server;

import com.google.common.collect.ImmutableList;
import com.google.inject.Binder;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.http.client.HttpClientBinder.HttpClientBindingBuilder;
import io.airlift.http.client.HttpClientConfig;
import io.airlift.http.client.HttpRequestFilter;

import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static java.util.Objects.requireNonNull;

public class InternalCommunicationHttpClientModule
extends AbstractConfigurationAwareModule
{
private final String clientName;
private final Class<? extends Annotation> annotation;
private final boolean withTracing;
private final Consumer<HttpClientConfig> configDefaults;
private final List<Class<? extends HttpRequestFilter>> filters;

private InternalCommunicationHttpClientModule(
String clientName,
Class<? extends Annotation> annotation,
boolean withTracing,
Consumer<HttpClientConfig> configDefaults,
List<Class<? extends HttpRequestFilter>> filters)
{
this.clientName = requireNonNull(clientName, "clientName is null");
this.annotation = requireNonNull(annotation, "annotation is null");
this.withTracing = withTracing;
this.configDefaults = requireNonNull(configDefaults, "configDefaults is null");
this.filters = ImmutableList.copyOf(requireNonNull(filters, "filters is null"));
}

@Override
protected void setup(Binder binder)
{
HttpClientBindingBuilder httpClientBindingBuilder = httpClientBinder(binder).bindHttpClient(clientName, annotation);
InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class);
httpClientBindingBuilder.withConfigDefaults(httpConfig -> {
configureClient(httpConfig, internalCommunicationConfig);
configDefaults.accept(httpConfig);
});

httpClientBindingBuilder.addFilterBinding().to(InternalAuthenticationManager.class);

if (withTracing) {
httpClientBindingBuilder.withTracing();
}

filters.forEach(httpClientBindingBuilder::withFilter);
}

static void configureClient(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig)
{
httpConfig.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled());
if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) {
configureClientForAutomaticHttps(httpConfig, internalCommunicationConfig);
}
else {
configureClientForManualHttps(httpConfig, internalCommunicationConfig);
}
}

private static void configureClientForAutomaticHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig)
{
String sharedSecret = internalCommunicationConfig.getSharedSecret()
.orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled"));
httpConfig.setAutomaticHttpsSharedSecret(sharedSecret);
}

private static void configureClientForManualHttps(HttpClientConfig httpConfig, InternalCommunicationConfig internalCommunicationConfig)
{
httpConfig.setKeyStorePath(internalCommunicationConfig.getKeyStorePath());
httpConfig.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword());
httpConfig.setTrustStorePath(internalCommunicationConfig.getTrustStorePath());
httpConfig.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword());
httpConfig.setAutomaticHttpsSharedSecret(null);
}

public static class Builder
{
private final String clientName;
private final Class<? extends Annotation> annotation;
private boolean withTracing;
private Consumer<HttpClientConfig> configDefaults = config -> {};
private final List<Class<? extends HttpRequestFilter>> filters = new ArrayList<>();

private Builder(String clientName, Class<? extends Annotation> annotation)
{
this.clientName = requireNonNull(clientName, "clientName is null");
this.annotation = requireNonNull(annotation, "annotation is null");
}

public Builder withTracing()
{
this.withTracing = true;
return this;
}

public Builder withConfigDefaults(Consumer<HttpClientConfig> configDefaults)
{
this.configDefaults = requireNonNull(configDefaults, "configDefaults is null");
return this;
}

public Builder withFilter(Class<? extends HttpRequestFilter> requestFilter)
{
this.filters.add(requestFilter);
return this;
}

public InternalCommunicationHttpClientModule build()
{
return new InternalCommunicationHttpClientModule(clientName, annotation, withTracing, configDefaults, filters);
}
}

public static InternalCommunicationHttpClientModule.Builder internalHttpClientModule(String clientName, Class<? extends Annotation> annotation)
{
return new Builder(clientName, annotation);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
package io.trino.server;

import com.google.inject.Binder;
import com.google.inject.multibindings.Multibinder;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.discovery.client.ForDiscoveryClient;
import io.airlift.http.client.HttpClientConfig;
import io.airlift.http.client.HttpRequestFilter;
import io.airlift.http.client.Request;
import io.airlift.http.server.HttpsConfig;
Expand All @@ -30,7 +30,6 @@

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static io.airlift.node.AddressToHostname.encodeAddressAsHostname;
import static io.airlift.node.NodeConfig.AddressSource.IP_ENCODED_AS_HOSTNAME;

Expand All @@ -40,38 +39,17 @@ public class InternalCommunicationModule
@Override
protected void setup(Binder binder)
{
// Set defaults for all HttpClients in the same guice context
// so in case of any additions or alternations here an update in:
// io.trino.server.security.jwt.JwtAuthenticatorSupportModule.JwkModule.configure
// and
// io.trino.server.security.oauth2.OAuth2ServiceModule.setup
// may also be required.
InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class);
Multibinder<HttpRequestFilter> discoveryFilterBinder = newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class);
if (internalCommunicationConfig.isHttpsRequired() && internalCommunicationConfig.getKeyStorePath() == null && internalCommunicationConfig.getTrustStorePath() == null) {
String sharedSecret = internalCommunicationConfig.getSharedSecret()
.orElseThrow(() -> new IllegalArgumentException("Internal shared secret must be set when internal HTTPS is enabled"));
configBinder(binder).bindConfigDefaults(HttpsConfig.class, config -> config.setAutomaticHttpsSharedSecret(sharedSecret));
configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> {
config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled());
config.setAutomaticHttpsSharedSecret(sharedSecret);
});
configBinder(binder).bindConfigGlobalDefaults(NodeConfig.class, config -> config.setInternalAddressSource(IP_ENCODED_AS_HOSTNAME));
// rewrite discovery client requests to use IP encoded as hostname
newSetBinder(binder, HttpRequestFilter.class, ForDiscoveryClient.class).addBinding().to(DiscoveryEncodeAddressAsHostname.class);
discoveryFilterBinder.addBinding().to(DiscoveryEncodeAddressAsHostname.class);
}
else {
configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> {
config.setHttp2Enabled(internalCommunicationConfig.isHttp2Enabled());
config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath());
config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword());
config.setTrustStorePath(internalCommunicationConfig.getTrustStorePath());
config.setTrustStorePassword(internalCommunicationConfig.getTrustStorePassword());
config.setAutomaticHttpsSharedSecret(null);
});
}

discoveryFilterBinder.addBinding().to(InternalAuthenticationManager.class);
binder.bind(InternalAuthenticationManager.class);
httpClientBinder(binder).bindGlobalFilter(InternalAuthenticationManager.class);
}

private static class DiscoveryEncodeAddressAsHostname
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@
import static io.airlift.configuration.ConditionalModule.conditionalModule;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.discovery.client.DiscoveryBinder.discoveryBinder;
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder;
import static io.airlift.json.JsonBinder.jsonBinder;
import static io.airlift.json.JsonCodecBinder.jsonCodecBinder;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.TOPOLOGY;
import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeSchedulerPolicy.UNIFORM;
import static io.trino.operator.RetryPolicy.TASK;
import static io.trino.server.InternalCommunicationHttpClientModule.internalHttpClientModule;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
Expand Down Expand Up @@ -255,12 +255,12 @@ protected void setup(Binder binder)
binder.bind(DiscoveryNodeManager.class).in(Scopes.SINGLETON);
binder.bind(InternalNodeManager.class).to(DiscoveryNodeManager.class).in(Scopes.SINGLETON);
newExporter(binder).export(DiscoveryNodeManager.class).withGeneratedName();
httpClientBinder(binder).bindHttpClient("node-manager", ForNodeManager.class)
install(internalHttpClientModule("node-manager", ForNodeManager.class)
.withTracing()
.withConfigDefaults(config -> {
config.setIdleTimeout(new Duration(30, SECONDS));
config.setRequestTimeout(new Duration(10, SECONDS));
});
}).build());

// node scheduler
// TODO: remove from NodePartitioningManager and move to CoordinatorModule
Expand Down Expand Up @@ -339,15 +339,15 @@ protected void setup(Binder binder)

// exchange client
binder.bind(DirectExchangeClientSupplier.class).to(DirectExchangeClientFactory.class).in(Scopes.SINGLETON);
httpClientBinder(binder).bindHttpClient("exchange", ForExchange.class)
install(internalHttpClientModule("exchange", ForExchange.class)
.withTracing()
.withFilter(GenerateTraceTokenRequestFilter.class)
.withConfigDefaults(config -> {
config.setIdleTimeout(new Duration(30, SECONDS));
config.setRequestTimeout(new Duration(10, SECONDS));
config.setMaxConnectionsPerServer(250);
config.setMaxContentLength(DataSize.of(32, MEGABYTE));
});
}).build());

configBinder(binder).bindConfig(DirectExchangeClientConfig.class);
binder.bind(ExchangeExecutionMBean.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,7 @@ private static class JwkModule
@Override
public void configure(Binder binder)
{
httpClientBinder(binder)
.bindHttpClient("jwk", ForJwt.class)
// Reset HttpClient default configuration to override InternalCommunicationModule changes.
// Setting a keystore and/or a truststore for internal communication changes the default SSL configuration
// for all clients in the same guice context. This, however, does not make sense for this client which will
// very rarely use the same SSL setup as internal communication, so using the system default truststore
// makes more sense.
.withConfigDefaults(config -> config
.setKeyStorePath(null)
.setKeyStorePassword(null)
.setTrustStorePath(null)
.setTrustStorePassword(null)
.setAutomaticHttpsSharedSecret(null));
httpClientBinder(binder).bindHttpClient("jwk", ForJwt.class);
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,7 @@ protected void setup(Binder binder)
.in(Scopes.SINGLETON);
install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableDiscovery, this::bindOidcDiscovery, this::bindStaticConfiguration));
install(conditionalModule(OAuth2Config.class, OAuth2Config::isEnableRefreshTokens, this::enableRefreshTokens, this::disableRefreshTokens));
httpClientBinder(binder)
.bindHttpClient("oauth2-jwk", ForOAuth2.class)
// Reset to defaults to override InternalCommunicationModule changes to this client default configuration.
// Setting a keystore and/or a truststore for internal communication changes the default SSL configuration
// for all clients in this guice context. This does not make sense for this client which will very rarely
// use the same SSL configuration, so using the system default truststore makes more sense.
.withConfigDefaults(config -> config
.setKeyStorePath(null)
.setKeyStorePassword(null)
.setTrustStorePath(null)
.setTrustStorePassword(null)
.setAutomaticHttpsSharedSecret(null));
httpClientBinder(binder).bindHttpClient("oauth2-jwk", ForOAuth2.class);
}

private void enableRefreshTokens(Binder binder)
Expand Down
Loading

0 comments on commit 6c8dd6e

Please sign in to comment.