Skip to content

Commit

Permalink
feat: support config AddressResolverGroup in r2dbc-mysql
Browse files Browse the repository at this point in the history
Signed-off-by: ZhangJian He <[email protected]>
  • Loading branch information
hezhangjian committed Jul 25, 2024
1 parent 508d6c3 commit 8a5e9cc
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.asyncer.r2dbc.mysql.extension.Extension;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import org.jetbrains.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.netty.resources.LoopResources;
Expand Down Expand Up @@ -127,6 +128,8 @@ public final class MySqlConnectionConfiguration {
@Nullable
private final Publisher<String> passwordPublisher;

private final AddressResolverGroup<?> resolver;

private MySqlConnectionConfiguration(
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
Expand All @@ -141,7 +144,8 @@ private MySqlConnectionConfiguration(
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver
) {
this.isHost = isHost;
this.domain = domain;
Expand Down Expand Up @@ -171,6 +175,7 @@ private MySqlConnectionConfiguration(
this.loopResources = loopResources == null ? TcpResources.get() : loopResources;
this.extensions = extensions;
this.passwordPublisher = passwordPublisher;
this.resolver = resolver;
}

/**
Expand Down Expand Up @@ -301,6 +306,11 @@ Publisher<String> getPasswordPublisher() {
return passwordPublisher;
}

@Nullable
AddressResolverGroup<?> getResolver() {
return resolver;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -494,6 +504,9 @@ public static final class Builder {
@Nullable
private Publisher<String> passwordPublisher;

@Nullable
private AddressResolverGroup<?> resolver;

/**
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
*
Expand Down Expand Up @@ -528,7 +541,7 @@ public MySqlConnectionConfiguration build() {
loadLocalInfilePath,
localInfileBufferSize, queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel, loopResources,
Extensions.from(extensions, autodetectExtensions), passwordPublisher);
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver);
}

/**
Expand Down Expand Up @@ -1156,6 +1169,11 @@ public Builder passwordPublisher(Publisher<String> passwordPublisher) {
return this;
}

public Builder resolver(AddressResolverGroup<?> resolver) {
this.resolver = resolver;
return this;
}

private SslMode requireSslMode() {
SslMode sslMode = this.sslMode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ private static Mono<MySqlConnection> getMySqlConnection(
configuration.isTcpNoDelay(),
context,
configuration.getConnectTimeout(),
configuration.getLoopResources()
configuration.getLoopResources(),
configuration.getResolver()
)).flatMap(client -> {
// Lazy init database after handshake/login
boolean deferDatabase = configuration.isCreateDatabaseIfNotExist();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryProvider;
Expand Down Expand Up @@ -308,6 +309,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<Publisher<String>> PASSWORD_PUBLISHER = Option.valueOf("passwordPublisher");

/**
* Option to set the {@link AddressResolverGroup} for resolving host addresses.
* <p>
* This can be used to customize the DNS resolution mechanism, which is particularly useful in environments
* with specific DNS configuration needs or where a custom DNS resolver is required.
* <p>
*
* @since 1.2.0
*/
public static final Option<AddressResolverGroup<?>> RESOLVER = Option.valueOf("resolver");

@Override
public ConnectionFactory create(ConnectionFactoryOptions options) {
requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import io.netty.resolver.AddressResolverGroup;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -132,7 +133,7 @@ public interface Client {
*/
static Mono<Client> connect(MySqlSslConfiguration ssl, SocketAddress address, boolean tcpKeepAlive,
boolean tcpNoDelay, ConnectionContext context, @Nullable Duration connectTimeout,
LoopResources loopResources) {
LoopResources loopResources, @Nullable AddressResolverGroup<?> resolverGroup) {
requireNonNull(ssl, "ssl must not be null");
requireNonNull(address, "address must not be null");
requireNonNull(context, "context must not be null");
Expand All @@ -150,6 +151,10 @@ static Mono<Client> connect(MySqlSslConfiguration ssl, SocketAddress address, bo
tcpClient = tcpClient.option(ChannelOption.TCP_NODELAY, tcpNoDelay);
}

if (resolverGroup != null) {
tcpClient = tcpClient.resolver(resolverGroup);
}

return tcpClient.remoteAddress(() -> address).connect()
.map(conn -> new ReactorNettyClient(conn, ssl, context));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.asyncer.r2dbc.mysql.extension.Extension;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.DefaultAddressResolverGroup;
import org.assertj.core.api.ObjectAssert;
import org.assertj.core.api.ThrowableTypeAssert;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -207,6 +209,19 @@ void validPasswordSupplier() {
.verifyComplete();
}

@Test
void validResolver() {
final AddressResolverGroup<?> resolver = DefaultAddressResolverGroup.INSTANCE;
AddressResolverGroup<?> resolverGroup = MySqlConnectionConfiguration.builder()
.host(HOST)
.user(USER)
.resolver(resolver)
.autodetectExtensions(false)
.build()
.getResolver();
assertThat(resolverGroup).isSameAs(resolver);
}

private static MySqlConnectionConfiguration unixSocketSslMode(SslMode sslMode) {
return MySqlConnectionConfiguration.builder()
.unixSocket(UNIX_SOCKET)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.DefaultAddressResolverGroup;
import io.r2dbc.spi.ConnectionFactories;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.Option;
Expand Down Expand Up @@ -50,6 +52,7 @@
import java.util.stream.Stream;

import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_PUBLISHER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.RESOLVER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT;
import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE;
Expand Down Expand Up @@ -453,6 +456,19 @@ void validPasswordSupplier() {
assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class);
}

@Test
void validResolver() {
final AddressResolverGroup<?> resolver = DefaultAddressResolverGroup.INSTANCE;
ConnectionFactoryOptions options = ConnectionFactoryOptions.builder()
.option(DRIVER, "mysql")
.option(HOST, "127.0.0.1")
.option(USER, "root")
.option(RESOLVER, resolver)
.build();

assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class);
}

@Test
void allConfigurationOptions() {
List<String> exceptConfigs = Arrays.asList(
Expand Down

0 comments on commit 8a5e9cc

Please sign in to comment.