Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Listeners > BootSSLContext: improve locking access to cached sslConte… #295

Merged
merged 3 commits into from
Jun 7, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 85 additions & 50 deletions carapace-server/src/main/java/org/carapaceproxy/server/Listeners.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.util.AsyncMapping;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Promise;
Expand Down Expand Up @@ -72,6 +71,9 @@
import io.netty.handler.timeout.IdleStateHandler;
import static org.carapaceproxy.utils.CertificatesUtils.loadKeyStoreFromFile;
import io.netty.handler.ssl.OpenSslCachingX509KeyManagerFactory;
import io.netty.util.concurrent.Future;
import java.util.HashMap;
import java.util.Map.Entry;

/**
*
Expand All @@ -90,7 +92,7 @@ public class Listeners {
private final EventLoopGroup workerGroup;
private final HttpProxyServer parent;
private final Map<String, SslContext> sslContexts = new ConcurrentHashMap<>();
private final Map<HostPort, Channel> listeningChannels = new ConcurrentHashMap<>();
private final Map<HostPort, ListeningChannel> listeningChannels = new ConcurrentHashMap<>();
private final Map<HostPort, ClientConnectionHandler> listenersHandlers = new ConcurrentHashMap<>();
private final File basePath;
private boolean started;
Expand Down Expand Up @@ -177,34 +179,26 @@ private SslContext bootSslContext(NetworkListenerConfiguration listener, SSLCert
}
}

private void bootListener(NetworkListenerConfiguration listener) throws InterruptedException {
int port = listener.getPort() + parent.getListenersOffsetPort();
LOG.log(Level.INFO, "Starting listener at {0}:{1} ssl:{2}", new Object[]{listener.getHost(), port, listener.isSsl()});
private void bootListener(NetworkListenerConfiguration config) throws InterruptedException {
int port = config.getPort() + parent.getListenersOffsetPort();
LOG.log(Level.INFO, "Starting listener at {0}:{1} ssl:{2}", new Object[]{config.getHost(), port, config.isSsl()});

AsyncMapping<String, SslContext> sniMappings = (String sniHostname, Promise<SslContext> promise) -> {
try {
SslContext sslContext = resolveSslContext(listener, sniHostname);
return promise.setSuccess(sslContext);
} catch (ConfigurationNotValidException err) {
LOG.log(Level.SEVERE, "Error booting certificate for SNI hostname {0}, on listener {1}", new Object[]{sniHostname, listener});
return promise.setFailure(err);
}
};
ListeningChannel listeningChannel = new ListeningChannel(config, port);

HostPort key = new HostPort(listener.getHost(), port);
HostPort key = new HostPort(config.getHost(), port);
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(Epoll.isAvailable() ? EpollServerSocketChannel.class : NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel channel) throws Exception {
CURRENT_CONNECTED_CLIENTS_GAUGE.inc();
if (listener.isSsl()) {
SniHandler sni = new SniHandler(sniMappings) {
if (config.isSsl()) {
SniHandler sni = new SniHandler(listeningChannel) {
@Override
protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
SslHandler handler = super.newSslHandler(context, allocator);
if (listener.isOcsp() && OpenSsl.isOcspSupported()) {
if (config.isOcsp() && OpenSsl.isOcspSupported()) {
Certificate cert = (Certificate) context.attributes().attr(AttributeKey.valueOf(OCSP_CERTIFICATE_CHAIN)).get();
if (cert != null) {
try {
Expand Down Expand Up @@ -233,9 +227,9 @@ protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocato
parent.getBackendHealthManager(),
parent.getRequestsLogger(),
parent.getFullHttpMessageLogger(),
listener.getHost(),
config.getHost(),
port,
listener.isSsl()
config.isSsl()
);
channel.pipeline().addLast(connHandler);
parent.getFullHttpMessageLogger().attachHandler(channel, connHandler.getPendingRequest());
Expand All @@ -245,16 +239,16 @@ protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocato
})
.option(ChannelOption.SO_BACKLOG, 128)
.childOption(ChannelOption.SO_KEEPALIVE, true);
Channel channel = b.bind(listener.getHost(), port).sync().channel();

listeningChannels.put(key, channel);
Channel channel = b.bind(config.getHost(), port).sync().channel();
listeningChannel.setChannel(channel);
listeningChannels.put(key, listeningChannel);
LOG.log(Level.INFO, "started listener at {0}: {1}", new Object[]{key, channel});

}

public int getLocalPort() {
for (Channel c : listeningChannels.values()) {
InetSocketAddress addr = (InetSocketAddress) c.localAddress();
for (ListeningChannel c : listeningChannels.values()) {
InetSocketAddress addr = (InetSocketAddress) c.getChannel().localAddress();
return addr.getPort();
}
return -1;
Expand Down Expand Up @@ -282,33 +276,72 @@ public void stop() {
}
}

private SslContext resolveSslContext(NetworkListenerConfiguration listener, String sniHostname) throws ConfigurationNotValidException {
int port = listener.getPort() + parent.getListenersOffsetPort();
String key = listener.getHost() + ":" + port + "+" + sniHostname;
if (LOG.isLoggable(Level.FINER)) {
LOG.log(Level.FINER, "resolve SNI mapping " + sniHostname + ", key: " + key);
private final class ListeningChannel implements io.netty.util.AsyncMapping<String, SslContext> {

private final NetworkListenerConfiguration config;
private final int port;
private final Map<String, SslContext> listenerSslContexts = new HashMap<>();
Channel channel;

public ListeningChannel(NetworkListenerConfiguration config, int port) {
this.config = config;
this.port = port;
}
try {
return sslContexts.computeIfAbsent(key, (k) -> {

public Channel getChannel() {
return channel;
}

public void setChannel(Channel channel) {
this.channel = channel;
}

@Override
public Future<SslContext> map(String sniHostname, Promise<SslContext> promise) {
try {
String key = config.getHost() + ":" + port + "+" + sniHostname;
if (LOG.isLoggable(Level.FINER)) {
LOG.log(Level.FINER, "resolve SNI mapping " + sniHostname + ", key: " + key);
}
try {
SSLCertificateConfiguration choosen = chooseCertificate(sniHostname, listener.getDefaultCertificate());
if (choosen == null) {
throw new ConfigurationNotValidException("cannot find a certificate for snihostname " + sniHostname
+ ", with default cert for listener as '" + listener.getDefaultCertificate()
+ "', available " + currentConfiguration.getCertificates().keySet());
SslContext sslContext = listenerSslContexts.get(key);
if (sslContext != null) {
return promise.setSuccess(sslContext);
}

sslContext = sslContexts.computeIfAbsent(key, (k) -> {
try {
SSLCertificateConfiguration choosen = chooseCertificate(sniHostname, config.getDefaultCertificate());
if (choosen == null) {
throw new ConfigurationNotValidException("cannot find a certificate for snihostname " + sniHostname
+ ", with default cert for listener as '" + config.getDefaultCertificate()
+ "', available " + currentConfiguration.getCertificates().keySet());
}
return bootSslContext(config, choosen);
} catch (ConfigurationNotValidException ex) {
throw new RuntimeException(ex);
}
});
listenerSslContexts.put(key, sslContext);

return promise.setSuccess(sslContext);
} catch (RuntimeException err) {
if (err.getCause() instanceof ConfigurationNotValidException) {
throw (ConfigurationNotValidException) err.getCause();
} else {
throw new ConfigurationNotValidException(err);
}
return bootSslContext(listener, choosen);
} catch (ConfigurationNotValidException err) {
throw new RuntimeException(err);
}
});
} catch (RuntimeException err) {
if (err.getCause() instanceof ConfigurationNotValidException) {
throw (ConfigurationNotValidException) err.getCause();
} else {
throw new ConfigurationNotValidException(err);
} catch (ConfigurationNotValidException err) {
LOG.log(Level.SEVERE, "Error booting certificate for SNI hostname {0}, on listener {1}", new Object[]{sniHostname, config});
return promise.setFailure(err);
}
}

public void clear() {
this.listenerSslContexts.clear();
}

}

@VisibleForTesting
Expand Down Expand Up @@ -364,9 +397,9 @@ void reloadConfiguration(RuntimeServerConfiguration newConfiguration) throws Int
List<HostPort> listenersToStop = new ArrayList<>();
List<HostPort> listenersToStart = new ArrayList<>();
List<HostPort> listenersToRestart = new ArrayList<>();
for (HostPort key : listeningChannels.keySet()) {
for (Entry<HostPort, ListeningChannel> channel : listeningChannels.entrySet()) {
HostPort key = channel.getKey();
NetworkListenerConfiguration actualListenerConfig = currentConfiguration.getListener(key);

NetworkListenerConfiguration newConfigurationForListener = newConfiguration.getListener(key);
if (newConfigurationForListener == null) {
LOG.log(Level.INFO, "listener: {0} is to be shut down", key);
Expand All @@ -375,6 +408,8 @@ void reloadConfiguration(RuntimeServerConfiguration newConfiguration) throws Int
LOG.log(Level.INFO, "listener: {0} is to be restarted", key);
listenersToRestart.add(key);
}

channel.getValue().clear();
}
for (NetworkListenerConfiguration config : newConfiguration.getListeners()) {
HostPort key = config.getKey();
Expand Down Expand Up @@ -413,9 +448,9 @@ void reloadConfiguration(RuntimeServerConfiguration newConfiguration) throws Int
}

private void stopListener(HostPort hostport) throws InterruptedException {
Channel channel = listeningChannels.remove(hostport);
ListeningChannel channel = listeningChannels.remove(hostport);
if (channel != null) {
channel.close().sync();
channel.channel.close().sync();
}
listenersHandlers.remove(hostport);
}
Expand Down