Skip to content

Commit

Permalink
Fix NullPointerException in TlsMetricsHandler when used together with…
Browse files Browse the repository at this point in the history
… SniHandler (#3023)

Fixes #3022
  • Loading branch information
violetagg authored Jan 9, 2024
1 parent e0e4c5d commit 8171384
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.ssl.AbstractSniHandler;
import io.netty.handler.ssl.SslHandler;
import reactor.netty.NettyPipeline;
import reactor.util.Logger;
import reactor.util.Loggers;
Expand Down Expand Up @@ -95,12 +97,19 @@ public void channelRegistered(ChannelHandlerContext ctx) {
NettyPipeline.ConnectMetricsHandler,
connectMetricsHandler());
}
if (ctx.pipeline().get(NettyPipeline.SslHandler) != null) {
ChannelHandler sslHandler = ctx.pipeline().get(NettyPipeline.SslHandler);
if (sslHandler instanceof SslHandler) {
ctx.pipeline()
.addBefore(NettyPipeline.SslHandler,
NettyPipeline.TlsMetricsHandler,
tlsMetricsHandler());
}
else if (sslHandler instanceof AbstractSniHandler) {
ctx.pipeline()
.addAfter(NettyPipeline.SslHandler,
NettyPipeline.TlsMetricsHandler,
tlsMetricsHandler());
}

ctx.fireChannelRegistered();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.ssl.SniCompletionEvent;
import io.netty.handler.ssl.SslHandler;
import reactor.util.annotation.Nullable;

Expand Down Expand Up @@ -86,28 +87,46 @@ static class TlsMetricsHandler extends ChannelInboundHandlerAdapter {

protected final ChannelMetricsRecorder recorder;

boolean listenerAdded;

TlsMetricsHandler(ChannelMetricsRecorder recorder) {
this.recorder = recorder;
}

@Override
public void channelActive(ChannelHandlerContext ctx) {
long tlsHandshakeTimeStart = System.nanoTime();
ctx.pipeline()
.get(SslHandler.class)
.handshakeFuture()
.addListener(f -> {
ctx.pipeline().remove(this);
recordTlsHandshakeTime(ctx, tlsHandshakeTimeStart, f.isSuccess() ? SUCCESS : ERROR);
});
addListener(ctx);
ctx.fireChannelActive();
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
addListener(ctx);
}
ctx.fireUserEventTriggered(evt);
}

protected void recordTlsHandshakeTime(ChannelHandlerContext ctx, long tlsHandshakeTimeStart, String status) {
recorder.recordTlsHandshakeTime(
ctx.channel().remoteAddress(),
Duration.ofNanos(System.nanoTime() - tlsHandshakeTimeStart),
status);
}

private void addListener(ChannelHandlerContext ctx) {
if (!listenerAdded) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
if (sslHandler != null) {
listenerAdded = true;
long tlsHandshakeTimeStart = System.nanoTime();
sslHandler.handshakeFuture()
.addListener(f -> {
ctx.pipeline().remove(this);
recordTlsHandshakeTime(ctx, tlsHandshakeTimeStart, f.isSuccess() ? SUCCESS : ERROR);
});
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2011-2023 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2011-2024 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -136,6 +136,7 @@
import reactor.netty.http.Http2SslContextSpec;
import reactor.netty.http.HttpProtocol;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientMetricsRecorder;
import reactor.netty.http.client.HttpClientRequest;
import reactor.netty.http.client.PrematureCloseException;
import reactor.netty.http.logging.ReactorNettyHttpMessageLogFactory;
Expand Down Expand Up @@ -2124,6 +2125,21 @@ void testHang() {

@Test
void testSniSupport() throws Exception {
doTestSniSupport(Function.identity(), Function.identity());
}

@Test
void testIssue3022() throws Exception {
TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder();
TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder();
doTestSniSupport(server -> server.metrics(true, () -> serverMetricsRecorder, Function.identity()),
client -> client.metrics(true, () -> clientMetricsRecorder, Function.identity()));
assertThat(clientMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO);
assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO);
}

private void doTestSniSupport(Function<HttpServer, HttpServer> serverCustomizer,
Function<HttpClient, HttpClient> clientCustomizer) throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
Http11SslContextSpec defaultSslContextBuilder =
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());
Expand All @@ -2138,7 +2154,7 @@ void testSniSupport() throws Exception {

AtomicReference<String> hostname = new AtomicReference<>();
disposableServer =
createServer()
serverCustomizer.apply(createServer())
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder)))
.doOnChannelInit((obs, channel, remoteAddress) ->
Expand All @@ -2155,7 +2171,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

createClient(disposableServer::address)
clientCustomizer.apply(createClient(disposableServer::address))
.secure(spec -> spec.sslContext(clientSslContextBuilder)
.serverNames(new SNIHostName("test.com")))
.get()
Expand Down Expand Up @@ -3569,4 +3585,112 @@ private void testIssue2927(Function<HttpServer, HttpServer> serverCustomizer, Fu
.expectErrorMatches(t -> t instanceof PrematureCloseException && t.getCause() instanceof Http2Exception.HeaderListSizeException)
.verify(Duration.ofSeconds(30));
}

static final class TestHttpServerMetricsRecorder implements HttpServerMetricsRecorder {

Duration tlsHandshakeTime;

@Override
public void recordDataReceived(SocketAddress remoteAddress, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, long bytes) {
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress) {
}

@Override
public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) {
tlsHandshakeTime = time;
}

@Override
public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) {
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress, String uri) {
}

@Override
public void recordDataReceivedTime(String uri, String method, Duration time) {
}

@Override
public void recordDataSentTime(String uri, String method, String status, Duration time) {
}

@Override
public void recordResponseTime(String uri, String method, String status, Duration time) {
}
}

static final class TestHttpClientMetricsRecorder implements HttpClientMetricsRecorder {

Duration tlsHandshakeTime;

@Override
public void recordDataReceived(SocketAddress remoteAddress, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, long bytes) {
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress) {
}

@Override
public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) {
tlsHandshakeTime = time;
}

@Override
public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) {
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress, String uri) {
}

@Override
public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) {
}

@Override
public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) {
}

@Override
public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) {
}
}
}

0 comments on commit 8171384

Please sign in to comment.