From e37cbddc2fdfb9ad2b12ea27425502a06c03b448 Mon Sep 17 00:00:00 2001 From: ZhangJian He Date: Thu, 25 Jul 2024 19:45:29 +0800 Subject: [PATCH] feat: support config AddressResolverGroup in r2dbc-mysql (#279) Motivation: Currently,`AddressResolverGroup` can't be configured. The DnsResolver default start address listen to "0.0.0.0", which may have some security risks. also see https://github.com/netty/netty/pull/11061 Modification: Add `AddressResolverGroup` in Client's connect method --------- Signed-off-by: ZhangJian He --- .../mysql/MySqlConnectionConfiguration.java | 83 ++++++++++++------- .../r2dbc/mysql/MySqlConnectionFactory.java | 3 +- .../mysql/MySqlConnectionFactoryProvider.java | 14 ++++ .../io/asyncer/r2dbc/mysql/client/Client.java | 7 +- .../MySqlConnectionConfigurationTest.java | 15 ++++ .../MySqlConnectionFactoryProviderTest.java | 16 ++++ 6 files changed, 104 insertions(+), 34 deletions(-) diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java index 3856b58bd..2f1c75961 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java @@ -22,6 +22,7 @@ import io.asyncer.r2dbc.mysql.extension.Extension; import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; import io.netty.handler.ssl.SslContextBuilder; +import io.netty.resolver.AddressResolverGroup; import org.jetbrains.annotations.Nullable; import org.reactivestreams.Publisher; import reactor.netty.resources.LoopResources; @@ -127,6 +128,9 @@ public final class MySqlConnectionConfiguration { @Nullable private final Publisher passwordPublisher; + @Nullable + private final AddressResolverGroup resolver; + private MySqlConnectionConfiguration( boolean isHost, String domain, int port, MySqlSslConfiguration ssl, boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout, @@ -141,7 +145,8 @@ private MySqlConnectionConfiguration( int queryCacheSize, int prepareCacheSize, Set compressionAlgorithms, int zstdCompressionLevel, @Nullable LoopResources loopResources, - Extensions extensions, @Nullable Publisher passwordPublisher + Extensions extensions, @Nullable Publisher passwordPublisher, + @Nullable AddressResolverGroup resolver ) { this.isHost = isHost; this.domain = domain; @@ -171,6 +176,7 @@ private MySqlConnectionConfiguration( this.loopResources = loopResources == null ? TcpResources.get() : loopResources; this.extensions = extensions; this.passwordPublisher = passwordPublisher; + this.resolver = resolver; } /** @@ -301,6 +307,11 @@ Publisher getPasswordPublisher() { return passwordPublisher; } + @Nullable + AddressResolverGroup getResolver() { + return resolver; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -337,7 +348,8 @@ public boolean equals(Object o) { zstdCompressionLevel == that.zstdCompressionLevel && Objects.equals(loopResources, that.loopResources) && extensions.equals(that.extensions) && - Objects.equals(passwordPublisher, that.passwordPublisher); + Objects.equals(passwordPublisher, that.passwordPublisher) && + Objects.equals(resolver, that.resolver); } @Override @@ -352,19 +364,26 @@ public int hashCode() { loadLocalInfilePath, localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, - loopResources, extensions, passwordPublisher); + loopResources, extensions, passwordPublisher, resolver); } @Override public String toString() { - if (isHost) { - return "MySqlConnectionConfiguration{host='" + domain + "', port=" + port + ", ssl=" + ssl + - ", tcpNoDelay=" + tcpNoDelay + ", tcpKeepAlive=" + tcpKeepAlive + - ", connectTimeout=" + connectTimeout + + return "MySqlConnectionConfiguration{" + + (isHost ? "host='" + domain + "', port=" + port + ", ssl=" + ssl + + ", tcpNoDelay=" + tcpNoDelay + ", tcpKeepAlive=" + tcpKeepAlive : + "unixSocket='" + domain + "'") + + buildCommonToStringPart() + + '}'; + } + + private String buildCommonToStringPart() { + return ", connectTimeout=" + connectTimeout + ", preserveInstants=" + preserveInstants + ", connectionTimeZone=" + connectionTimeZone + ", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession + - ", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password + + ", zeroDateOption=" + zeroDateOption + + ", user='" + user + "', password=" + password + ", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist + ", preferPrepareStatement=" + preferPrepareStatement + ", sessionVariables=" + sessionVariables + @@ -372,32 +391,14 @@ public String toString() { ", statementTimeout=" + statementTimeout + ", loadLocalInfilePath=" + loadLocalInfilePath + ", localInfileBufferSize=" + localInfileBufferSize + - ", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize + + ", queryCacheSize=" + queryCacheSize + + ", prepareCacheSize=" + prepareCacheSize + ", compressionAlgorithms=" + compressionAlgorithms + ", zstdCompressionLevel=" + zstdCompressionLevel + ", loopResources=" + loopResources + - ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; - } - - return "MySqlConnectionConfiguration{unixSocket='" + domain + - "', connectTimeout=" + connectTimeout + - ", preserveInstants=" + preserveInstants + - ", connectionTimeZone=" + connectionTimeZone + - ", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession + - ", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password + - ", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist + - ", preferPrepareStatement=" + preferPrepareStatement + - ", sessionVariables=" + sessionVariables + - ", lockWaitTimeout=" + lockWaitTimeout + - ", statementTimeout=" + statementTimeout + - ", loadLocalInfilePath=" + loadLocalInfilePath + - ", localInfileBufferSize=" + localInfileBufferSize + - ", queryCacheSize=" + queryCacheSize + - ", prepareCacheSize=" + prepareCacheSize + - ", compressionAlgorithms=" + compressionAlgorithms + - ", zstdCompressionLevel=" + zstdCompressionLevel + - ", loopResources=" + loopResources + - ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; + ", extensions=" + extensions + + ", passwordPublisher=" + passwordPublisher + + ", resolver=" + resolver; } /** @@ -494,6 +495,9 @@ public static final class Builder { @Nullable private Publisher passwordPublisher; + @Nullable + private AddressResolverGroup resolver; + /** * Builds an immutable {@link MySqlConnectionConfiguration} with current options. * @@ -528,7 +532,7 @@ public MySqlConnectionConfiguration build() { loadLocalInfilePath, localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, loopResources, - Extensions.from(extensions, autodetectExtensions), passwordPublisher); + Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver); } /** @@ -1156,6 +1160,21 @@ public Builder passwordPublisher(Publisher passwordPublisher) { return this; } + /** + * Sets the {@link AddressResolverGroup} for resolving host addresses. + *

+ * This can be used to customize the DNS resolution mechanism, which is particularly useful in environments + * with specific DNS configuration needs or where a custom DNS resolver is required. + * + * @param resolver the resolver group to use for host address resolution. + * @return this {@link Builder}. + * @since 1.2.0 + */ + public Builder resolver(AddressResolverGroup resolver) { + this.resolver = resolver; + return this; + } + private SslMode requireSslMode() { SslMode sslMode = this.sslMode; diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java index d003db2b0..bff85c809 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java @@ -147,7 +147,8 @@ private static Mono getMySqlConnection( configuration.isTcpNoDelay(), context, configuration.getConnectTimeout(), - configuration.getLoopResources() + configuration.getLoopResources(), + configuration.getResolver() )).flatMap(client -> { // Lazy init database after handshake/login boolean deferDatabase = configuration.isCreateDatabaseIfNotExist(); diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java index f6dc1a57a..d89005394 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java @@ -20,6 +20,7 @@ import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.netty.handler.ssl.SslContextBuilder; +import io.netty.resolver.AddressResolverGroup; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.ConnectionFactoryProvider; @@ -308,6 +309,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr */ public static final Option> PASSWORD_PUBLISHER = Option.valueOf("passwordPublisher"); + /** + * Option to set the {@link AddressResolverGroup} for resolving host addresses. + *

+ * This can be used to customize the DNS resolution mechanism, which is particularly useful in environments + * with specific DNS configuration needs or where a custom DNS resolver is required. + *

+ * + * @since 1.2.0 + */ + public static final Option> RESOLVER = Option.valueOf("resolver"); + @Override public ConnectionFactory create(ConnectionFactoryOptions options) { requireNonNull(options, "connectionFactoryOptions must not be null"); @@ -389,6 +401,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) { .to(builder::loopResources); mapper.optional(PASSWORD_PUBLISHER).as(Publisher.class) .to(builder::passwordPublisher); + mapper.optional(RESOLVER).as(AddressResolverGroup.class) + .to(builder::resolver); mapper.optional(SESSION_VARIABLES).asArray( String[].class, Function.identity(), diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java index d7c3ac28a..0beaf4c0d 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java @@ -22,6 +22,7 @@ import io.asyncer.r2dbc.mysql.message.server.ServerMessage; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelOption; +import io.netty.resolver.AddressResolverGroup; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.jetbrains.annotations.Nullable; @@ -132,7 +133,7 @@ public interface Client { */ static Mono connect(MySqlSslConfiguration ssl, SocketAddress address, boolean tcpKeepAlive, boolean tcpNoDelay, ConnectionContext context, @Nullable Duration connectTimeout, - LoopResources loopResources) { + LoopResources loopResources, @Nullable AddressResolverGroup resolver) { requireNonNull(ssl, "ssl must not be null"); requireNonNull(address, "address must not be null"); requireNonNull(context, "context must not be null"); @@ -150,6 +151,10 @@ static Mono connect(MySqlSslConfiguration ssl, SocketAddress address, bo tcpClient = tcpClient.option(ChannelOption.TCP_NODELAY, tcpNoDelay); } + if (resolver != null) { + tcpClient = tcpClient.resolver(resolver); + } + return tcpClient.remoteAddress(() -> address).connect() .map(conn -> new ReactorNettyClient(conn, ssl, context)); } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java index f050f4e4a..f05defb17 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java @@ -22,6 +22,8 @@ import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.asyncer.r2dbc.mysql.extension.Extension; import io.netty.handler.ssl.SslContextBuilder; +import io.netty.resolver.AddressResolverGroup; +import io.netty.resolver.DefaultAddressResolverGroup; import org.assertj.core.api.ObjectAssert; import org.assertj.core.api.ThrowableTypeAssert; import org.jetbrains.annotations.Nullable; @@ -207,6 +209,19 @@ void validPasswordSupplier() { .verifyComplete(); } + @Test + void validResolver() { + final AddressResolverGroup resolver = DefaultAddressResolverGroup.INSTANCE; + AddressResolverGroup resolverGroup = MySqlConnectionConfiguration.builder() + .host(HOST) + .user(USER) + .resolver(resolver) + .autodetectExtensions(false) + .build() + .getResolver(); + assertThat(resolverGroup).isSameAs(resolver); + } + private static MySqlConnectionConfiguration unixSocketSslMode(SslMode sslMode) { return MySqlConnectionConfiguration.builder() .unixSocket(UNIX_SOCKET) diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java index ab75161c1..1e71a9f17 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java @@ -20,6 +20,8 @@ import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.netty.handler.ssl.SslContextBuilder; +import io.netty.resolver.AddressResolverGroup; +import io.netty.resolver.DefaultAddressResolverGroup; import io.r2dbc.spi.ConnectionFactories; import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.Option; @@ -50,6 +52,7 @@ import java.util.stream.Stream; import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_PUBLISHER; +import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.RESOLVER; import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT; import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT; import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE; @@ -453,6 +456,19 @@ void validPasswordSupplier() { assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class); } + @Test + void validResolver() { + final AddressResolverGroup resolver = DefaultAddressResolverGroup.INSTANCE; + ConnectionFactoryOptions options = ConnectionFactoryOptions.builder() + .option(DRIVER, "mysql") + .option(HOST, "127.0.0.1") + .option(USER, "root") + .option(RESOLVER, resolver) + .build(); + + assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class); + } + @Test void allConfigurationOptions() { List exceptConfigs = Arrays.asList(