Skip to content

Commit

Permalink
Add SNI support for HTTP/3 (#3496)
Browse files Browse the repository at this point in the history
  • Loading branch information
violetagg authored Nov 5, 2024
1 parent e0c23c0 commit 126565d
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import reactor.netty.transport.logging.AdvancedByteBufFormat;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Incubating;
import reactor.util.annotation.Nullable;

import static reactor.netty.ReactorNetty.format;
Expand Down Expand Up @@ -285,6 +286,28 @@ public interface GenericSslContextSpec<B> {
SslContext sslContext() throws SSLException;
}

@Incubating
public interface GenericSslContextSpecWithSniSupport<B> extends GenericSslContextSpec<B> {

/**
* Configures the underlying {@link SslContext}.
*
* @param sslCtxBuilder a callback for configuring the underlying {@link SslContext}
* @return {@code this}
*/
@Override
GenericSslContextSpecWithSniSupport<B> configure(Consumer<B> sslCtxBuilder);

/**
* Create a new {@link SslContext} instance with the configured settings.
*
* @param sniMappings {@code SNI} configuration per domain
* @return a new {@link SslContext} instance
* @throws SSLException thrown when {@link SslContext} instance cannot be created
*/
SslContext sslContext(Map<String, SslProvider> sniMappings) throws SSLException;
}

/**
* SslContext builder that provides, specific for the protocol, default configuration.
* The default configuration is applied prior any other custom configuration.
Expand All @@ -305,13 +328,20 @@ public interface ProtocolSslContextSpec extends GenericSslContextSpec<SslContext
final int builderHashCode;
final SniProvider sniProvider;
final Map<String, SslProvider> confPerDomainName;
final List<SNIServerName> serverNames;
final AsyncMapping<String, SslProvider> sniMappings;

SslProvider(SslProvider.Build builder) {
this.confPerDomainName = builder.confPerDomainName;
if (builder.sslContext == null) {
if (builder.genericSslContextSpec != null) {
try {
this.sslContext = builder.genericSslContextSpec.sslContext();
if (!confPerDomainName.isEmpty() && builder.genericSslContextSpec instanceof GenericSslContextSpecWithSniSupport) {
this.sslContext = ((GenericSslContextSpecWithSniSupport<?>) builder.genericSslContextSpec).sslContext(confPerDomainName);
}
else {
this.sslContext = builder.genericSslContextSpec.sslContext();
}
}
catch (SSLException e) {
throw Exceptions.propagate(e);
Expand All @@ -324,12 +354,13 @@ public interface ProtocolSslContextSpec extends GenericSslContextSpec<SslContext
else {
this.sslContext = builder.sslContext;
}
if (builder.serverNames != null) {
this.serverNames = builder.serverNames;
if (serverNames != null) {
Consumer<SslHandler> configurator =
h -> {
SSLEngine engine = h.engine();
SSLParameters sslParameters = engine.getSSLParameters();
sslParameters.setServerNames(builder.serverNames);
sslParameters.setServerNames(serverNames);
engine.setSSLParameters(sslParameters);
};
this.handlerConfigurator = builder.handlerConfigurator == null ? configurator :
Expand All @@ -342,7 +373,6 @@ public interface ProtocolSslContextSpec extends GenericSslContextSpec<SslContext
this.closeNotifyFlushTimeoutMillis = builder.closeNotifyFlushTimeoutMillis;
this.closeNotifyReadTimeoutMillis = builder.closeNotifyReadTimeoutMillis;
this.builderHashCode = builder.hashCode();
this.confPerDomainName = builder.confPerDomainName;
this.sniMappings = builder.sniMappings;
if (!confPerDomainName.isEmpty()) {
this.sniProvider = new SniProvider(confPerDomainName, this);
Expand Down Expand Up @@ -371,6 +401,7 @@ else if (sniMappings != null) {
this.closeNotifyReadTimeoutMillis = from.closeNotifyReadTimeoutMillis;
this.builderHashCode = from.builderHashCode;
this.confPerDomainName = from.confPerDomainName;
this.serverNames = from.serverNames;
this.sniMappings = from.sniMappings;
this.sniProvider = from.sniProvider;
}
Expand All @@ -384,6 +415,12 @@ public SslContext getSslContext() {
return this.sslContext;
}

@Incubating
@Nullable
public List<SNIServerName> getServerNames() {
return serverNames;
}

public void configure(SslHandler sslHandler) {
Objects.requireNonNull(sslHandler, "sslHandler");
sslHandler.setHandshakeTimeoutMillis(handshakeTimeoutMillis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.ssl.SniCompletionEvent;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.incubator.codec.quic.InsecureQuicTokenHandler;
Expand Down Expand Up @@ -52,6 +55,7 @@
import reactor.util.annotation.Nullable;
import reactor.util.function.Tuple2;

import javax.net.ssl.SNIHostName;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.cert.CertificateException;
Expand Down Expand Up @@ -651,6 +655,51 @@ void testProtocolVersion() {
.verify(Duration.ofSeconds(5));
}

@Test
void testSniSupport() throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
SelfSignedCertificate testCert = new SelfSignedCertificate("test.com");

AtomicReference<String> hostname = new AtomicReference<>();

Http3SslContextSpec defaultSslContextBuilder = Http3SslContextSpec.forServer(defaultCert.key(), null, defaultCert.cert());
Http3SslContextSpec testSslContextBuilder = Http3SslContextSpec.forServer(testCert.key(), null, testCert.cert());

disposableServer =
createServer().port(8080)
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder)))
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

Http3SslContextSpec clientSslContextBuilder =
Http3SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

createClient(disposableServer.port())
.secure(spec -> spec.sslContext(clientSslContextBuilder)
.serverNames(new SNIHostName("test.com")))
.get()
.uri("/")
.responseContent()
.aggregate()
.block(Duration.ofSeconds(30));

assertThat(hostname.get()).isNotNull();
assertThat(hostname.get()).isEqualTo("test.com");
}

@Test
void testTrailerHeadersChunkedResponse() {
disposableServer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
package reactor.netty.http;

import io.netty.handler.ssl.SslContext;
import io.netty.incubator.codec.quic.QuicSslContext;
import io.netty.incubator.codec.quic.QuicSslContextBuilder;
import io.netty.util.DomainWildcardMappingBuilder;
import reactor.netty.tcp.SslProvider;
import reactor.util.annotation.Incubating;
import reactor.util.annotation.Nullable;
Expand All @@ -27,10 +29,12 @@
import java.io.File;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

import static io.netty.incubator.codec.http3.Http3.supportedApplicationProtocols;
import static io.netty.incubator.codec.quic.QuicSslContextBuilder.buildForServerWithSni;

/**
* SslContext builder that provides default configuration specific to HTTP/3 as follows:
Expand All @@ -44,7 +48,7 @@
* @see io.netty.incubator.codec.http3.Http3#supportedApplicationProtocols()
*/
@Incubating
public final class Http3SslContextSpec implements SslProvider.GenericSslContextSpec<QuicSslContextBuilder> {
public final class Http3SslContextSpec implements SslProvider.GenericSslContextSpecWithSniSupport<QuicSslContextBuilder> {

/**
* Creates a builder for new client-side {@link SslContext}.
Expand Down Expand Up @@ -103,6 +107,14 @@ public SslContext sslContext() throws SSLException {
return sslContextBuilder.build();
}

@Override
public SslContext sslContext(Map<String, SslProvider> sniMappings) throws SSLException {
DomainWildcardMappingBuilder<QuicSslContext> mappingsSslProviderBuilder =
new DomainWildcardMappingBuilder<>((QuicSslContext) sslContext());
sniMappings.forEach((s, sslProvider) -> mappingsSslProviderBuilder.add(s, (QuicSslContext) sslProvider.getSslContext()));
return buildForServerWithSni(mappingsSslProviderBuilder.build());
}

final QuicSslContextBuilder sslContextBuilder;

Http3SslContextSpec(QuicSslContextBuilder sslContextBuilder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,20 @@
import io.netty.channel.ChannelInitializer;
import io.netty.incubator.codec.quic.QuicClientCodecBuilder;
import io.netty.incubator.codec.quic.QuicSslContext;
import io.netty.incubator.codec.quic.QuicSslEngine;
import reactor.netty.Connection;
import reactor.netty.ConnectionObserver;
import reactor.netty.NettyPipeline;
import reactor.netty.channel.ChannelOperations;
import reactor.netty.http.Http3SettingsSpec;
import reactor.netty.tcp.SslProvider;
import reactor.util.annotation.Nullable;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;

import static io.netty.incubator.codec.http3.Http3.newQuicClientCodecBuilder;
Expand All @@ -39,25 +47,52 @@ final class Http3ChannelInitializer extends ChannelInitializer<Channel> {
final ConnectionObserver obs;
final ChannelOperations.OnSetup opsFactory;
final ChannelInitializer<Channel> quicChannelInitializer;
final QuicSslContext quicSslContext;
final SocketAddress remoteAddress;
final SslProvider sslProvider;

Http3ChannelInitializer(HttpClientConfig config, ChannelInitializer<Channel> quicChannelInitializer, ConnectionObserver obs) {
Http3ChannelInitializer(HttpClientConfig config, ChannelInitializer<Channel> quicChannelInitializer, ConnectionObserver obs,
@Nullable SocketAddress remoteAddress) {
this.http3Settings = config.http3SettingsSpec();
this.loggingHandler = config.loggingHandler();
this.obs = obs;
this.opsFactory = config.channelOperationsProvider();
this.quicChannelInitializer = quicChannelInitializer;
if (config.sslProvider.getSslContext() instanceof QuicSslContext) {
this.quicSslContext = (QuicSslContext) config.sslProvider.getSslContext();
}
else {
throw new IllegalArgumentException("The configured SslContext is not QuicSslContext");
}
this.remoteAddress = remoteAddress;
this.sslProvider = config.sslProvider;
}

@Override
protected void initChannel(Channel channel) {
QuicClientCodecBuilder quicClientCodecBuilder = newQuicClientCodecBuilder().sslContext(quicSslContext);
QuicClientCodecBuilder quicClientCodecBuilder = newQuicClientCodecBuilder();

quicClientCodecBuilder.sslEngineProvider(ch -> {
QuicSslContext quicSslContext;
if (sslProvider.getSslContext() instanceof QuicSslContext) {
quicSslContext = (QuicSslContext) sslProvider.getSslContext();
}
else {
throw new IllegalArgumentException("The configured SslContext is not QuicSslContext");
}

QuicSslEngine engine;
if (remoteAddress instanceof InetSocketAddress) {
InetSocketAddress sniInfo = (InetSocketAddress) remoteAddress;
if (sslProvider.getServerNames() != null && !sslProvider.getServerNames().isEmpty()) {
SNIServerName serverName = sslProvider.getServerNames().get(0);
String serverNameStr = serverName instanceof SNIHostName ? ((SNIHostName) serverName).getAsciiName() :
new String(serverName.getEncoded(), StandardCharsets.US_ASCII);
engine = quicSslContext.newEngine(ch.alloc(), serverNameStr, sniInfo.getPort());
}
else {
engine = quicSslContext.newEngine(ch.alloc(), sniInfo.getHostString(), sniInfo.getPort());
}
}
else {
engine = quicSslContext.newEngine(ch.alloc());
}

return engine;
});

if (http3Settings != null) {
quicClientCodecBuilder.initialMaxData(http3Settings.maxData())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ public WebsocketClientSpec websocketClientSpec() {
public ChannelInitializer<Channel> channelInitializer(ConnectionObserver connectionObserver,
@Nullable SocketAddress remoteAddress, boolean onServer) {
ChannelInitializer<Channel> channelInitializer = super.channelInitializer(connectionObserver, remoteAddress, onServer);
return (_protocols & h3) == h3 ? new Http3ChannelInitializer(this, channelInitializer, connectionObserver) : channelInitializer;
return (_protocols & h3) == h3 ? new Http3ChannelInitializer(this, channelInitializer, connectionObserver, remoteAddress) : channelInitializer;
}

/**
Expand Down

0 comments on commit 126565d

Please sign in to comment.