From cfef3f117a04109892e37799c139b80476baa92c Mon Sep 17 00:00:00 2001 From: Mirro Mutth Date: Tue, 12 Mar 2024 15:01:08 +0900 Subject: [PATCH] Add support for multiple hosts configuration - Allow to use Mono for user and password - Add multiple hosts connection strategy - Add HA protocol support for multiple hosts - Allow to use DNS SRV records for HA protocol --- .../r2dbc/mysql/ConnectionStrategy.java | 142 ++++ .../io/asyncer/r2dbc/mysql/Credential.java | 70 ++ .../mysql/MultiHostsConnectionStrategy.java | 207 +++++ .../r2dbc/mysql/MySqlBatchingBatch.java | 11 +- .../mysql/MySqlConnectionConfiguration.java | 719 ++++++++++-------- .../r2dbc/mysql/MySqlConnectionFactory.java | 114 +-- .../mysql/MySqlConnectionFactoryProvider.java | 156 ++-- .../r2dbc/mysql/MySqlSimpleConnection.java | 43 +- .../r2dbc/mysql/MySqlSslConfiguration.java | 95 ++- .../r2dbc/mysql/MySqlStatementSupport.java | 15 +- .../r2dbc/mysql/MySqlSyntheticBatch.java | 7 +- .../io/asyncer/r2dbc/mysql/OptionMapper.java | 31 +- .../mysql/ParametrizedStatementSupport.java | 11 +- .../mysql/PrepareParametrizedStatement.java | 8 +- .../r2dbc/mysql/PrepareSimpleStatement.java | 7 +- .../io/asyncer/r2dbc/mysql/QueryFlow.java | 88 ++- .../r2dbc/mysql/SimpleStatementSupport.java | 7 +- .../mysql/SingleHostConnectionStrategy.java | 52 ++ .../mysql/SocketClientConfiguration.java | 95 +++ .../r2dbc/mysql/SocketConfiguration.java | 28 + .../r2dbc/mysql/TcpSocketConfiguration.java | 235 ++++++ .../mysql/TextParametrizedStatement.java | 10 +- .../r2dbc/mysql/TextSimpleStatement.java | 7 +- .../mysql/UnixDomainSocketConfiguration.java | 75 ++ .../UnixDomainSocketConnectionStrategy.java | 48 ++ .../io/asyncer/r2dbc/mysql/client/Client.java | 64 +- .../mysql/client/ReactorNettyClient.java | 5 + .../r2dbc/mysql/constant/HaProtocol.java | 93 +++ .../r2dbc/mysql/constant/ProtocolDriver.java | 80 ++ .../r2dbc/mysql/internal/NodeAddress.java | 76 ++ .../mysql/internal/util/AddressUtils.java | 102 ++- .../mysql/internal/util/InternalArrays.java | 2 +- .../mysql/HaProtocolIntegrationTest.java | 81 ++ .../r2dbc/mysql/MySqlBatchingBatchTest.java | 6 +- .../MySqlConnectionConfigurationTest.java | 121 ++- .../MySqlConnectionFactoryProviderTest.java | 106 ++- .../mysql/MySqlSimpleConnectionTest.java | 28 +- .../r2dbc/mysql/MySqlSyntheticBatchTest.java | 3 +- .../r2dbc/mysql/MySqlTestKitSupport.java | 15 +- .../asyncer/r2dbc/mysql/OptionMapperTest.java | 18 +- .../PrepareParametrizedStatementTest.java | 8 +- .../mysql/PrepareSimpleStatementTest.java | 8 +- .../mysql/ProtocolDriverIntegrationTest.java | 64 ++ .../mysql/TextParametrizedStatementTest.java | 10 +- .../r2dbc/mysql/TextSimpleStatementTest.java | 9 +- .../r2dbc/mysql/TimeZoneIntegrationTest.java | 23 +- .../mysql/internal/util/AddressUtilsTest.java | 220 ++++-- 47 files changed, 2559 insertions(+), 864 deletions(-) create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionStrategy.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/Credential.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MultiHostsConnectionStrategy.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SingleHostConnectionStrategy.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketClientConfiguration.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketConfiguration.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TcpSocketConfiguration.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConfiguration.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConnectionStrategy.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/HaProtocol.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/ProtocolDriver.java create mode 100644 r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/NodeAddress.java create mode 100644 r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/HaProtocolIntegrationTest.java create mode 100644 r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ProtocolDriverIntegrationTest.java diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionStrategy.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionStrategy.java new file mode 100644 index 000000000..e7e232da2 --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ConnectionStrategy.java @@ -0,0 +1,142 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.client.Client; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.SslMode; +import io.netty.channel.ChannelOption; +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import io.netty.resolver.DefaultNameResolver; +import io.netty.resolver.RoundRobinInetAddressResolver; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import reactor.core.publisher.Mono; +import reactor.netty.resources.LoopResources; +import reactor.netty.tcp.TcpClient; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.Set; + +/** + * An interface of a connection strategy that considers how to obtain a MySQL {@link Client} object. + * + * @since 1.2.0 + */ +@FunctionalInterface +interface ConnectionStrategy { + + InternalLogger logger = InternalLoggerFactory.getInstance(ConnectionStrategy.class); + + /** + * Establish a connection to a target server that is determined by this connection strategy. + * + * @return a logged-in {@link Client} object. + */ + Mono connect(); + + /** + * Creates a general-purpose {@link TcpClient} with the given {@link SocketClientConfiguration}. + *

+ * Note: Unix Domain Socket also uses this method to create a general-purpose {@link TcpClient client}. + * + * @param configuration socket client configuration. + * @return a general-purpose {@link TcpClient client}. + */ + static TcpClient createTcpClient(SocketClientConfiguration configuration, boolean balancedDns) { + LoopResources loopResources = configuration.getLoopResources(); + Duration connectTimeout = configuration.getConnectTimeout(); + TcpClient client = TcpClient.newConnection(); + + if (loopResources != null) { + client = client.runOn(loopResources); + } + + if (connectTimeout != null) { + client = client.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(connectTimeout.toMillis())); + } + + if (balancedDns) { + client = client.resolver(BalancedResolverGroup.INSTANCE); + } + + return client; + } + + /** + * Logins to a MySQL server with the given {@link TcpClient}, {@link Credential} and configurations. + * + * @param tcpClient a TCP client to connect to a MySQL server. + * @param credential user and password to log in to a MySQL server. + * @param configuration a configuration that affects login behavior. + * @return a logged-in {@link Client} object. + */ + static Mono login( + TcpClient tcpClient, + Credential credential, + MySqlConnectionConfiguration configuration + ) { + MySqlSslConfiguration ssl = configuration.getSsl(); + SslMode sslMode = ssl.getSslMode(); + boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist(); + String database = configuration.getDatabase(); + String loginDb = createDbIfNotExist ? "" : database; + Set compressionAlgorithms = configuration.getCompressionAlgorithms(); + int zstdLevel = configuration.getZstdCompressionLevel(); + ConnectionContext context = new ConnectionContext( + configuration.getZeroDateOption(), + configuration.getLoadLocalInfilePath(), + configuration.getLocalInfileBufferSize(), + configuration.isPreserveInstants(), + configuration.retrieveConnectionZoneId() + ); + + return Client.connect(tcpClient, ssl, context).flatMap(client -> + QueryFlow.login(client, sslMode, loginDb, credential, compressionAlgorithms, zstdLevel, context)); + } +} + +/** + * Resolves the {@link InetSocketAddress} to IP address, randomly select one if it resolves to multiple IP addresses. + *

+ * Note: DNS resolution should have no relation to the connection strategy of HA protocol. + * + * @since 1.2.0 + */ +final class BalancedResolverGroup extends AddressResolverGroup { + + BalancedResolverGroup() { + } + + public static final BalancedResolverGroup INSTANCE; + + static { + INSTANCE = new BalancedResolverGroup(); + Runtime.getRuntime().addShutdownHook(new Thread( + INSTANCE::close, + "R2DBC-MySQL-BalancedResolverGroup-ShutdownHook" + )); + } + + @Override + protected AddressResolver newResolver(EventExecutor executor) { + return new RoundRobinInetAddressResolver(executor, new DefaultNameResolver(executor)).asAddressResolver(); + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/Credential.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/Credential.java new file mode 100644 index 000000000..82cb1168d --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/Credential.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import org.jetbrains.annotations.Nullable; + +import java.util.Objects; + +/** + * A value object representing a user with an optional password. + */ +final class Credential { + + private final String user; + + @Nullable + private final CharSequence password; + + Credential(String user, @Nullable CharSequence password) { + this.user = user; + this.password = password; + } + + String getUser() { + return user; + } + + @Nullable + CharSequence getPassword() { + return password; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Credential)) { + return false; + } + + Credential that = (Credential) o; + + return user.equals(that.user) && Objects.equals(password, that.password); + } + + @Override + public int hashCode() { + return 31 * user.hashCode() + Objects.hashCode(password); + } + + @Override + public String toString() { + return "Credential{user=" + user + ", password=REDACTED}"; + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MultiHostsConnectionStrategy.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MultiHostsConnectionStrategy.java new file mode 100644 index 000000000..b2f91195e --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MultiHostsConnectionStrategy.java @@ -0,0 +1,207 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.client.Client; +import io.asyncer.r2dbc.mysql.constant.ProtocolDriver; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; +import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoop; +import io.netty.resolver.DefaultNameResolver; +import io.netty.resolver.NameResolver; +import io.netty.util.concurrent.Future; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import org.jetbrains.annotations.Nullable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpResources; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; + +/** + * An abstraction for {@link ConnectionStrategy} that consider multiple hosts. + */ +final class MultiHostsConnectionStrategy implements ConnectionStrategy { + + private final Mono client; + + MultiHostsConnectionStrategy( + TcpSocketConfiguration tcp, + MySqlConnectionConfiguration configuration, + boolean shuffle + ) { + this.client = Mono.defer(() -> { + if (ProtocolDriver.DNS_SRV.equals(tcp.getDriver())) { + return resolveAllHosts(TcpResources.get().onClient(true).next(), tcp.getAddresses(), shuffle) + .flatMap(addresses -> connectHost(addresses, tcp, configuration, false, shuffle, 0)); + } else { + List availableHosts = copyAvailableAddresses(tcp.getAddresses(), shuffle); + int size = availableHosts.size(); + InetSocketAddress[] addresses = new InetSocketAddress[availableHosts.size()]; + + for (int i = 0; i < size; i++) { + NodeAddress address = availableHosts.get(i); + addresses[i] = InetSocketAddress.createUnresolved(address.getHost(), address.getPort()); + } + + return connectHost(InternalArrays.asImmutableList(addresses), tcp, configuration, true, shuffle, 0); + } + }); + } + + @Override + public Mono connect() { + return client; + } + + private Mono connectHost( + List addresses, + TcpSocketConfiguration tcp, + MySqlConnectionConfiguration configuration, + boolean balancedDns, + boolean shuffle, + int attempts + ) { + Iterator iter = addresses.iterator(); + + if (!iter.hasNext()) { + return Mono.error(fail("Fail to establish connection: no available host", null)); + } + + return attemptConnect(iter.next(), tcp, configuration, balancedDns).onErrorResume(t -> + resumeConnect(t, addresses, iter, tcp, configuration, balancedDns, shuffle, attempts)); + } + + private Mono resumeConnect( + Throwable t, + List addresses, + Iterator iter, + TcpSocketConfiguration tcp, + MySqlConnectionConfiguration configuration, + boolean balancedDns, + boolean shuffle, + int attempts + ) { + if (!iter.hasNext()) { + // The last host failed to connect + if (attempts >= tcp.getRetriesAllDown()) { + return Mono.error(fail( + "Fail to establish connection, retried " + attempts + " times: " + t.getMessage(), t)); + } + + logger.warn("All hosts failed to establish connections, auto-try again after 250ms."); + + // Ignore waiting error, e.g. interrupted, scheduler rejected + return Mono.delay(Duration.ofMillis(250)) + .onErrorComplete() + .then(Mono.defer(() -> connectHost(addresses, tcp, configuration, balancedDns, shuffle, attempts + 1))); + } + + return attemptConnect(iter.next(), tcp, configuration, balancedDns).onErrorResume(tt -> + resumeConnect(tt, addresses, iter, tcp, configuration, balancedDns, shuffle, attempts)); + } + + private Mono attemptConnect( + InetSocketAddress address, + TcpSocketConfiguration tcp, + MySqlConnectionConfiguration configuration, + boolean balancedDns + ) { + return configuration.getCredential().flatMap(credential -> { + TcpClient tcpClient = ConnectionStrategy.createTcpClient(configuration.getClient(), balancedDns) + .option(ChannelOption.SO_KEEPALIVE, tcp.isTcpKeepAlive()) + .option(ChannelOption.TCP_NODELAY, tcp.isTcpNoDelay()) + .remoteAddress(() -> address); + + return ConnectionStrategy.login(tcpClient, credential, configuration); + }).doOnError(e -> logger.warn("Fail to connect: ", e)); + } + + private static Mono> resolveAllHosts( + EventLoop eventLoop, + List addresses, + boolean shuffle + ) { + // Or DnsNameResolver? It is non-blocking but requires native dependencies, hard configurations, and maybe + // behaves differently. Currently, we use DefaultNameResolver which is blocking but simple and easy to use. + DefaultNameResolver resolver = new DefaultNameResolver(eventLoop); + + return Flux.fromIterable(addresses) + .flatMap(address -> resolveAll(resolver, address.getHost()) + .flatMapIterable(Function.identity()) + .map(inet -> new InetSocketAddress(inet, address.getPort()))) + .doFinally(ignore -> resolver.close()) + .collectList() + .map(list -> { + if (shuffle) { + Collections.shuffle(list); + } + + return list; + }); + } + + private static Mono> resolveAll(NameResolver resolver, String host) { + Future> future = resolver.resolveAll(host); + + return Mono.>create(sink -> future.addListener(f -> { + if (f.isSuccess()) { + try { + @SuppressWarnings("unchecked") + List t = (List) f.getNow(); + + logger.debug("Resolve {} in DNS succeed, {} records", host, t.size()); + sink.success(t); + } catch (Throwable e) { + logger.warn("Resolve {} in DNS succeed but failed to get result", host, e); + sink.success(Collections.emptyList()); + } + } else { + logger.warn("Resolve {} in DNS failed", host, f.cause()); + sink.success(Collections.emptyList()); + } + })).doOnCancel(() -> future.cancel(false)); + } + + private static List copyAvailableAddresses(List addresses, boolean shuffle) { + if (shuffle) { + List copied = new ArrayList<>(addresses); + Collections.shuffle(copied); + return copied; + } + + return InternalArrays.asImmutableList(addresses.toArray(new NodeAddress[0])); + } + + private static R2dbcNonTransientResourceException fail(String message, @Nullable Throwable cause) { + return new R2dbcNonTransientResourceException( + message, + "H1000", + 9000, + cause + ); + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatch.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatch.java index d85ebfd9e..158cc1799 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatch.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatch.java @@ -27,8 +27,8 @@ import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** - * An implementation of {@link MySqlBatch} for executing a collection of statements in a batch against the - * MySQL database. + * An implementation of {@link MySqlBatch} for executing a collection of statements in a batch against the MySQL + * database. */ final class MySqlBatchingBatch implements MySqlBatch { @@ -36,14 +36,11 @@ final class MySqlBatchingBatch implements MySqlBatch { private final Codecs codecs; - private final ConnectionContext context; - private final StringJoiner queries = new StringJoiner(";"); - MySqlBatchingBatch(Client client, Codecs codecs, ConnectionContext context) { + MySqlBatchingBatch(Client client, Codecs codecs) { this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); - this.context = requireNonNull(context, "context must not be null"); } @Override @@ -65,7 +62,7 @@ public MySqlBatch add(String sql) { @Override public Flux execute() { return QueryFlow.execute(client, getSql()) - .map(messages -> MySqlSegmentResult.toResult(false, codecs, context, null, messages)); + .map(messages -> MySqlSegmentResult.toResult(false, codecs, client.getContext(), null, messages)); } @Override diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java index 5953495ce..38bbb0592 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java @@ -17,13 +17,17 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.HaProtocol; +import io.asyncer.r2dbc.mysql.constant.ProtocolDriver; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.asyncer.r2dbc.mysql.extension.Extension; import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import io.netty.handler.ssl.SslContextBuilder; import org.jetbrains.annotations.Nullable; import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; import reactor.netty.resources.LoopResources; import reactor.netty.tcp.TcpResources; @@ -38,47 +42,28 @@ import java.util.EnumSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.ServiceLoader; import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; -import static io.asyncer.r2dbc.mysql.internal.util.InternalArrays.EMPTY_STRINGS; /** * A configuration of MySQL connection. */ public final class MySqlConnectionConfiguration { - /** - * Default MySQL port. - */ - private static final int DEFAULT_PORT = 3306; - - /** - * {@code true} if {@link #domain} is hostname, otherwise {@link #domain} is unix domain socket path. - */ - private final boolean isHost; - - /** - * Domain of connecting, may be hostname or unix domain socket path. - */ - private final String domain; + private final SocketClientConfiguration client; - private final int port; + private final SocketConfiguration socket; private final MySqlSslConfiguration ssl; - private final boolean tcpKeepAlive; - - private final boolean tcpNoDelay; - - @Nullable - private final Duration connectTimeout; - private final boolean preserveInstants; private final String connectionTimeZone; @@ -87,10 +72,9 @@ public final class MySqlConnectionConfiguration { private final ZeroDateOption zeroDateOption; - private final String user; + private final Mono user; - @Nullable - private final CharSequence password; + private final Mono> password; private final String database; @@ -114,42 +98,35 @@ public final class MySqlConnectionConfiguration { private final int zstdCompressionLevel; - private final LoopResources loopResources; - private final Extensions extensions; - @Nullable - private final Publisher passwordPublisher; - private MySqlConnectionConfiguration( - boolean isHost, String domain, int port, MySqlSslConfiguration ssl, - boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout, + SocketClientConfiguration client, + SocketConfiguration socket, + MySqlSslConfiguration ssl, ZeroDateOption zeroDateOption, boolean preserveInstants, String connectionTimeZone, boolean forceConnectionTimeZoneToSession, - String user, @Nullable CharSequence password, @Nullable String database, + Mono user, + Mono> password, + @Nullable String database, boolean createDatabaseIfNotExist, @Nullable Predicate preferPrepareStatement, List sessionVariables, @Nullable Path loadLocalInfilePath, int localInfileBufferSize, int queryCacheSize, int prepareCacheSize, Set compressionAlgorithms, int zstdCompressionLevel, - @Nullable LoopResources loopResources, - Extensions extensions, @Nullable Publisher passwordPublisher + Extensions extensions ) { - this.isHost = isHost; - this.domain = domain; - this.port = port; - this.tcpKeepAlive = tcpKeepAlive; - this.tcpNoDelay = tcpNoDelay; - this.connectTimeout = connectTimeout; - this.ssl = ssl; + this.client = requireNonNull(client, "client must not be null"); + this.socket = requireNonNull(socket, "socket must not be null"); + this.ssl = requireNonNull(ssl, "ssl must not be null"); this.preserveInstants = preserveInstants; this.connectionTimeZone = requireNonNull(connectionTimeZone, "connectionTimeZone must not be null"); this.forceConnectionTimeZoneToSession = forceConnectionTimeZoneToSession; this.zeroDateOption = requireNonNull(zeroDateOption, "zeroDateOption must not be null"); this.user = requireNonNull(user, "user must not be null"); - this.password = password; + this.password = requireNonNull(password, "password must not be null"); this.database = database == null || database.isEmpty() ? "" : database; this.createDatabaseIfNotExist = createDatabaseIfNotExist; this.preferPrepareStatement = preferPrepareStatement; @@ -160,9 +137,7 @@ private MySqlConnectionConfiguration( this.prepareCacheSize = prepareCacheSize; this.compressionAlgorithms = compressionAlgorithms; this.zstdCompressionLevel = zstdCompressionLevel; - this.loopResources = loopResources == null ? TcpResources.get() : loopResources; this.extensions = extensions; - this.passwordPublisher = passwordPublisher; } /** @@ -174,35 +149,18 @@ public static Builder builder() { return new Builder(); } - boolean isHost() { - return isHost; - } - - String getDomain() { - return domain; + SocketClientConfiguration getClient() { + return client; } - int getPort() { - return port; - } - - @Nullable - Duration getConnectTimeout() { - return connectTimeout; + SocketConfiguration getSocket() { + return socket; } MySqlSslConfiguration getSsl() { return ssl; } - boolean isTcpKeepAlive() { - return this.tcpKeepAlive; - } - - boolean isTcpNoDelay() { - return this.tcpNoDelay; - } - ZeroDateOption getZeroDateOption() { return zeroDateOption; } @@ -215,17 +173,25 @@ String getConnectionTimeZone() { return connectionTimeZone; } - boolean isForceConnectionTimeZoneToSession() { - return forceConnectionTimeZoneToSession; + @Nullable + ZoneId retrieveConnectionZoneId() { + String timeZone = this.connectionTimeZone; + + if ("LOCAL".equalsIgnoreCase(timeZone)) { + return ZoneId.systemDefault().normalized(); + } else if ("SERVER".equalsIgnoreCase(timeZone)) { + return null; + } + + return StringUtils.parseZoneId(timeZone); } - String getUser() { - return user; + boolean isForceConnectionTimeZoneToSession() { + return forceConnectionTimeZoneToSession; } - @Nullable - CharSequence getPassword() { - return password; + Mono getCredential() { + return Mono.zip(user, password, (u, p) -> new Credential(u, p.orElse(null))); } String getDatabase() { @@ -270,19 +236,10 @@ int getZstdCompressionLevel() { return zstdCompressionLevel; } - LoopResources getLoopResources() { - return loopResources; - } - Extensions getExtensions() { return extensions; } - @Nullable - Publisher getPasswordPublisher() { - return passwordPublisher; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -291,20 +248,18 @@ public boolean equals(Object o) { if (!(o instanceof MySqlConnectionConfiguration)) { return false; } + MySqlConnectionConfiguration that = (MySqlConnectionConfiguration) o; - return isHost == that.isHost && - domain.equals(that.domain) && - port == that.port && + + return client.equals(that.client) && + socket.equals(that.socket) && ssl.equals(that.ssl) && - tcpKeepAlive == that.tcpKeepAlive && - tcpNoDelay == that.tcpNoDelay && - Objects.equals(connectTimeout, that.connectTimeout) && preserveInstants == that.preserveInstants && - Objects.equals(connectionTimeZone, that.connectionTimeZone) && + connectionTimeZone.equals(that.connectionTimeZone) && forceConnectionTimeZoneToSession == that.forceConnectionTimeZoneToSession && zeroDateOption == that.zeroDateOption && user.equals(that.user) && - Objects.equals(password, that.password) && + password.equals(that.password) && database.equals(that.database) && createDatabaseIfNotExist == that.createDatabaseIfNotExist && Objects.equals(preferPrepareStatement, that.preferPrepareStatement) && @@ -315,50 +270,48 @@ public boolean equals(Object o) { prepareCacheSize == that.prepareCacheSize && compressionAlgorithms.equals(that.compressionAlgorithms) && zstdCompressionLevel == that.zstdCompressionLevel && - Objects.equals(loopResources, that.loopResources) && - extensions.equals(that.extensions) && - Objects.equals(passwordPublisher, that.passwordPublisher); + extensions.equals(that.extensions); } @Override public int hashCode() { - return Objects.hash(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay, connectTimeout, - preserveInstants, connectionTimeZone, forceConnectionTimeZoneToSession, - zeroDateOption, user, password, database, createDatabaseIfNotExist, - preferPrepareStatement, sessionVariables, loadLocalInfilePath, - localInfileBufferSize, queryCacheSize, prepareCacheSize, compressionAlgorithms, - zstdCompressionLevel, loopResources, extensions, passwordPublisher); + int result = client.hashCode(); + + result = 31 * result + socket.hashCode(); + result = 31 * result + ssl.hashCode(); + result = 31 * result + (preserveInstants ? 1 : 0); + result = 31 * result + connectionTimeZone.hashCode(); + result = 31 * result + (forceConnectionTimeZoneToSession ? 1 : 0); + result = 31 * result + zeroDateOption.hashCode(); + result = 31 * result + user.hashCode(); + result = 31 * result + password.hashCode(); + result = 31 * result + database.hashCode(); + result = 31 * result + (createDatabaseIfNotExist ? 1 : 0); + result = 31 * result + (preferPrepareStatement != null ? preferPrepareStatement.hashCode() : 0); + result = 31 * result + sessionVariables.hashCode(); + result = 31 * result + (loadLocalInfilePath != null ? loadLocalInfilePath.hashCode() : 0); + result = 31 * result + localInfileBufferSize; + result = 31 * result + queryCacheSize; + result = 31 * result + prepareCacheSize; + result = 31 * result + compressionAlgorithms.hashCode(); + result = 31 * result + zstdCompressionLevel; + + return 31 * result + extensions.hashCode(); } @Override public String toString() { - if (isHost) { - return "MySqlConnectionConfiguration{host='" + domain + "', port=" + port + ", ssl=" + ssl + - ", tcpNoDelay=" + tcpNoDelay + ", tcpKeepAlive=" + tcpKeepAlive + - ", connectTimeout=" + connectTimeout + - ", preserveInstants=" + preserveInstants + - ", connectionTimeZone=" + connectionTimeZone + - ", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession + - ", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password + - ", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist + - ", preferPrepareStatement=" + preferPrepareStatement + - ", sessionVariables=" + sessionVariables + - ", loadLocalInfilePath=" + loadLocalInfilePath + - ", localInfileBufferSize=" + localInfileBufferSize + - ", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize + - ", compressionAlgorithms=" + compressionAlgorithms + - ", zstdCompressionLevel=" + zstdCompressionLevel + - ", loopResources=" + loopResources + - ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; - } - - return "MySqlConnectionConfiguration{unixSocket='" + domain + - "', connectTimeout=" + connectTimeout + + return "MySqlConnectionConfiguration{client=" + client + + ", socket=" + socket + + ", ssl=" + ssl + ", preserveInstants=" + preserveInstants + - ", connectionTimeZone=" + connectionTimeZone + + ", connectionTimeZone='" + connectionTimeZone + '\'' + ", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession + - ", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password + - ", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist + + ", zeroDateOption=" + zeroDateOption + + ", user=" + user + + ", password=REDACTED" + + ", database='" + database + '\'' + + ", createDatabaseIfNotExist=" + createDatabaseIfNotExist + ", preferPrepareStatement=" + preferPrepareStatement + ", sessionVariables=" + sessionVariables + ", loadLocalInfilePath=" + loadLocalInfilePath + @@ -367,8 +320,8 @@ public String toString() { ", prepareCacheSize=" + prepareCacheSize + ", compressionAlgorithms=" + compressionAlgorithms + ", zstdCompressionLevel=" + zstdCompressionLevel + - ", loopResources=" + loopResources + - ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; + ", extensions=" + extensions + + '}'; } /** @@ -376,24 +329,26 @@ public String toString() { */ public static final class Builder { - @Nullable - private String database; + private final SocketClientConfiguration.Builder client = new SocketClientConfiguration.Builder(); - private boolean createDatabaseIfNotExist; + @Nullable + private TcpSocketConfiguration.Builder tcpSocket; - private boolean isHost = true; + @Nullable + private UnixDomainSocketConfiguration.Builder unixSocket; - private String domain; + private final MySqlSslConfiguration.Builder ssl = new MySqlSslConfiguration.Builder(); @Nullable - private CharSequence password; + private String database; - private int port = DEFAULT_PORT; + private boolean createDatabaseIfNotExist; @Nullable - private Duration connectTimeout; + private Mono user; - private String user; + @Nullable + private Mono password; private ZeroDateOption zeroDateOption = ZeroDateOption.USE_NULL; @@ -403,33 +358,6 @@ public static final class Builder { private boolean forceConnectionTimeZoneToSession; - @Nullable - private SslMode sslMode; - - private String[] tlsVersion = EMPTY_STRINGS; - - @Nullable - private HostnameVerifier sslHostnameVerifier; - - @Nullable - private String sslCa; - - @Nullable - private String sslKey; - - @Nullable - private CharSequence sslKeyPassword; - - @Nullable - private String sslCert; - - @Nullable - private Function sslContextBuilderCustomizer; - - private boolean tcpKeepAlive; - - private boolean tcpNoDelay; - @Nullable private Predicate preferPrepareStatement; @@ -449,54 +377,63 @@ public static final class Builder { private int zstdCompressionLevel = 3; - @Nullable - private LoopResources loopResources; - private boolean autodetectExtensions = true; private final List extensions = new ArrayList<>(); - @Nullable - private Publisher passwordPublisher; - /** * Builds an immutable {@link MySqlConnectionConfiguration} with current options. * * @return the {@link MySqlConnectionConfiguration}. */ public MySqlConnectionConfiguration build() { - SslMode sslMode = requireSslMode(); - - if (isHost) { - requireNonNull(domain, "host must not be null when using TCP socket"); - require((sslCert == null && sslKey == null) || (sslCert != null && sslKey != null), - "sslCert and sslKey must be both null or both non-null"); + Mono user = requireNonNull(this.user, "User must be configured"); + Mono auth = this.password; + Mono> password = auth == null ? Mono.just(Optional.empty()) : auth.singleOptional(); + SocketConfiguration socket; + boolean preferredSsl; + + if (unixSocket == null) { + socket = requireNonNull(tcpSocket, "Connection must be either TCP/SSL or Unix Domain Socket").build(); + preferredSsl = true; } else { - requireNonNull(domain, "unixSocket must not be null when using unix domain socket"); - require(!sslMode.startSsl(), "sslMode must be disabled when using unix domain socket"); + // Since 1.2.0, we support SSL over Unix Domain Socket, default SSL mode is DISABLED. + // But, if a Unix Domain Socket can be listened to by someone, this indicates that the system itself + // has been compromised, and enabling SSL does not improve the security of the connection. + socket = unixSocket.build(); + preferredSsl = false; } int prepareCacheSize = preferPrepareStatement == null ? 0 : this.prepareCacheSize; - MySqlSslConfiguration ssl = MySqlSslConfiguration.create(sslMode, tlsVersion, sslHostnameVerifier, - sslCa, sslKey, sslKeyPassword, sslCert, sslContextBuilderCustomizer); - return new MySqlConnectionConfiguration(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay, - connectTimeout, zeroDateOption, + return new MySqlConnectionConfiguration( + client.build(), + socket, + ssl.build(preferredSsl), + zeroDateOption, preserveInstants, connectionTimeZone, forceConnectionTimeZoneToSession, - user, password, database, - createDatabaseIfNotExist, preferPrepareStatement, sessionVariables, loadLocalInfilePath, - localInfileBufferSize, queryCacheSize, prepareCacheSize, - compressionAlgorithms, zstdCompressionLevel, loopResources, - Extensions.from(extensions, autodetectExtensions), passwordPublisher); + user.single(), + password, + database, + createDatabaseIfNotExist, + preferPrepareStatement, + sessionVariables, + loadLocalInfilePath, + localInfileBufferSize, + queryCacheSize, + prepareCacheSize, + compressionAlgorithms, + zstdCompressionLevel, + Extensions.from(extensions, autodetectExtensions)); } /** * Configures the database. Default no database. * * @param database the database, or {@code null} if no database want to be login. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.1 */ public Builder database(@Nullable String database) { @@ -509,7 +446,7 @@ public Builder database(@Nullable String database) { * {@code false}. * * @param enabled to discover and register extensions. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 1.0.6 */ public Builder createDatabaseIfNotExist(boolean enabled) { @@ -519,58 +456,116 @@ public Builder createDatabaseIfNotExist(boolean enabled) { /** * Configures the Unix Domain Socket to connect to. + *

+ * Note: It will override all TCP and SSL configurations if configured. * - * @param unixSocket the socket file path. - * @return this {@link Builder}. - * @throws IllegalArgumentException if {@code unixSocket} is {@code null}. + * @param path the socket file path. + * @return {@link Builder this} + * @throws IllegalArgumentException if {@code path} is {@code null}. * @since 0.8.1 */ - public Builder unixSocket(String unixSocket) { - this.domain = requireNonNull(unixSocket, "unixSocket must not be null"); - this.isHost = false; + public Builder unixSocket(String path) { + requireNonNull(path, "path must not be null"); + + requireUnixSocket().path(path); return this; } /** - * Configures the host. + * Configures the single-host. + *

+ * Note: Used only if the {@link #unixSocket(String)} and {@link #addHost multiple hosts} is not configured. * * @param host the host. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code host} is {@code null}. * @since 0.8.1 */ public Builder host(String host) { - this.domain = requireNonNull(host, "host must not be null"); - this.isHost = true; + requireNonEmpty(host, "host must not be empty"); + + requireTcpSocket().host(host); return this; } /** - * Configures the password. Default login without password. + * Configures the port of {@link #host(String)}. Defaults to {@code 3306}. *

- * Note: for memory security, should not use intern {@link String} for password. + * Note: Used only if the {@link #unixSocket(String)} and {@link #addHost multiple hosts} is not configured. * - * @param password the password, or {@code null} when user has no password. - * @return this {@link Builder}. + * @param port the port. + * @return {@link Builder this} + * @throws IllegalArgumentException if the {@code port} is negative or bigger than {@literal 65535}. * @since 0.8.1 */ - public Builder password(@Nullable CharSequence password) { - this.password = password; + public Builder port(int port) { + require(port >= 0 && port <= 0xFFFF, "port must be between 0 and 65535"); + + requireTcpSocket().port(port); return this; } /** - * Configures the port. Defaults to {@code 3306}. + * Adds a host with default port 3306 to the list of multiple hosts to connect to. + *

+ * Note: Used only if the {@link #unixSocket(String)} and {@link #host single host} is not configured. * - * @param port the port. - * @return this {@link Builder}. - * @throws IllegalArgumentException if the {@code port} is negative or bigger than {@literal 65535}. - * @since 0.8.1 + * @param host the host to add. + * @return {@link Builder this} + * @since 1.2.0 */ - public Builder port(int port) { + public Builder addHost(String host) { + requireNonEmpty(host, "host must not be empty"); + + requireTcpSocket().addHost(host); + return this; + } + + /** + * Adds a host to the list of multiple hosts to connect to. + *

+ * Note: Used only if the {@link #unixSocket(String)} and {@link #host single host} is not configured. + * + * @param host the host to add. + * @param port the port of the host. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder addHost(String host, int port) { + requireNonEmpty(host, "host must not be empty"); require(port >= 0 && port <= 0xFFFF, "port must be between 0 and 65535"); - this.port = port; + requireTcpSocket().addHost(host, port); + return this; + } + + /** + * Configures the failover and high availability protocol driver. Default to {@link ProtocolDriver#MYSQL}. Used + * only if the {@link #unixSocket(String)} is not configured. + * + * @param driver the protocol driver. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder driver(ProtocolDriver driver) { + requireNonNull(driver, "driver must not be null"); + + requireTcpSocket().driver(driver); + return this; + } + + /** + * Configures the failover and high availability protocol. Default to {@link HaProtocol#DEFAULT}. Used only if + * the {@link #unixSocket(String)} is not configured. + * + * @param protocol the failover and high availability protocol. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder protocol(HaProtocol protocol) { + requireNonNull(protocol, "protocol must not be null"); + + requireTcpSocket().protocol(protocol); return this; } @@ -578,11 +573,11 @@ public Builder port(int port) { * Configures the connection timeout. Default no timeout. * * @param connectTimeout the connection timeout, or {@code null} if no timeout. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.1 */ public Builder connectTimeout(@Nullable Duration connectTimeout) { - this.connectTimeout = connectTimeout; + this.client.connectTimeout(connectTimeout); return this; } @@ -590,20 +585,50 @@ public Builder connectTimeout(@Nullable Duration connectTimeout) { * Configures the user for login the database. * * @param user the user. - * @return this {@link Builder}. - * @throws IllegalArgumentException if {@code user} is {@code null}. + * @return {@link Builder this} + * @throws IllegalArgumentException if {@code user} is empty. * @since 0.8.2 */ public Builder user(String user) { - this.user = requireNonNull(user, "user must not be null"); + requireNonEmpty(user, "user must not be null"); + + this.user = Mono.just(user); return this; } /** - * An alias of {@link #user(String)}. + * Configures the user for login the database. + * + * @param user a {@link Supplier} to retrieve user. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder user(Supplier user) { + requireNonNull(user, "user must not be null"); + + this.user = Mono.fromSupplier(user); + return this; + } + + /** + * Configures the user for login the database. + * + * @param user a {@link Publisher} to retrieve user. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder user(Publisher user) { + requireNonNull(user, "user must not be null"); + + this.user = Mono.from(user); + return this; + } + + /** + * Configures the user for login the database. Since 0.8.2, it is an alias of {@link #user(String)}. * * @param user the user. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code user} is {@code null}. * @since 0.8.1 */ @@ -612,8 +637,46 @@ public Builder username(String user) { } /** - * Configures the time zone conversion. Default to {@code true} means enable conversion between JVM - * and {@link #connectionTimeZone(String)}. + * Configures the password. Default login without password. + *

+ * Note: for memory security, should not use intern {@link String} for password. + * + * @param password the password, or {@code null} when user has no password. + * @return {@link Builder this} + * @since 0.8.1 + */ + public Builder password(@Nullable CharSequence password) { + this.password = Mono.justOrEmpty(password); + return this; + } + + /** + * Configures the password. Default login without password. + * + * @param password a {@link Supplier} to retrieve password. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder password(Supplier password) { + this.password = Mono.fromSupplier(password); + return this; + } + + /** + * Configures the password. Default login without password. + * + * @param password a {@link Publisher} to retrieve password. + * @return {@link Builder this} + * @since 1.2.0 + */ + public Builder password(Publisher password) { + this.password = Mono.from(password); + return this; + } + + /** + * Configures the time zone conversion. Default to {@code true} means enable conversion between JVM and + * {@link #connectionTimeZone(String)}. *

* Note: disable it will ignore the time zone of connection, and use the JVM local time zone. * @@ -643,8 +706,8 @@ public Builder connectionTimeZone(String connectionTimeZone) { } /** - * Configures to force the connection time zone to session time zone. Default to {@code false}. Used - * only if the {@link #connectionTimeZone(String)} is not {@code "SERVER"}. + * Configures to force the connection time zone to session time zone. Default to {@code false}. Used only if + * the {@link #connectionTimeZone(String)} is not {@code "SERVER"}. *

* Note: alter the time zone of session will affect the results of MySQL date/time functions, e.g. * {@code NOW([n])}, {@code CURRENT_TIME([n])}, {@code CURRENT_DATE()}, etc. Please use with caution. @@ -672,11 +735,11 @@ public Builder serverZoneId(@Nullable ZoneId serverZoneId) { } /** - * Configures the {@link ZeroDateOption}. Default to {@link ZeroDateOption#USE_NULL}. It is a - * behavior option when this driver receives a value of zero-date. + * Configures the {@link ZeroDateOption}. Default to {@link ZeroDateOption#USE_NULL}. It is a behavior option + * when this driver receives a value of zero-date. * * @param zeroDate the {@link ZeroDateOption}. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code zeroDate} is {@code null}. * @since 0.8.1 */ @@ -687,46 +750,46 @@ public Builder zeroDateOption(ZeroDateOption zeroDate) { /** * Configures ssl mode. See also {@link SslMode}. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param sslMode the SSL mode to use. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code sslMode} is {@code null}. * @since 0.8.1 */ public Builder sslMode(SslMode sslMode) { - this.sslMode = requireNonNull(sslMode, "sslMode must not be null"); + requireNonNull(sslMode, "sslMode must not be null"); + + this.ssl.sslMode(sslMode); return this; } /** * Configures TLS versions, see {@link io.asyncer.r2dbc.mysql.constant.TlsVersions TlsVersions}. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param tlsVersion TLS versions. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if the array {@code tlsVersion} is {@code null}. * @since 0.8.1 */ public Builder tlsVersion(String... tlsVersion) { requireNonNull(tlsVersion, "tlsVersion must not be null"); - int size = tlsVersion.length; - - if (size > 0) { - String[] versions = new String[size]; - System.arraycopy(tlsVersion, 0, versions, 0, size); - this.tlsVersion = versions; - } else { - this.tlsVersion = EMPTY_STRINGS; - } + this.ssl.tlsVersions(tlsVersion); return this; } /** * Configures SSL {@link HostnameVerifier}, it is available only set {@link #sslMode(SslMode)} as - * {@link SslMode#VERIFY_IDENTITY}. It is useful when server was using special Certificates or need - * special verification. + * {@link SslMode#VERIFY_IDENTITY}. It is useful when server was using special Certificates or need special + * verification. *

* Default is builtin {@link HostnameVerifier} which use RFC standards. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param sslHostnameVerifier the custom {@link HostnameVerifier}. * @return this {@link Builder} @@ -734,8 +797,9 @@ public Builder tlsVersion(String... tlsVersion) { * @since 0.8.2 */ public Builder sslHostnameVerifier(HostnameVerifier sslHostnameVerifier) { - this.sslHostnameVerifier = requireNonNull(sslHostnameVerifier, - "sslHostnameVerifier must not be null"); + requireNonNull(sslHostnameVerifier, "sslHostnameVerifier must not be null"); + + this.ssl.sslHostnameVerifier(sslHostnameVerifier); return this; } @@ -744,41 +808,47 @@ public Builder sslHostnameVerifier(HostnameVerifier sslHostnameVerifier) { * {@link #sslMode(SslMode)} is configured for verify server certification. *

* Default is {@code null}, which means that the default algorithm is used for the trust manager. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param sslCa an X.509 certificate chain file in PEM format. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.1 */ public Builder sslCa(@Nullable String sslCa) { - this.sslCa = sslCa; + this.ssl.sslCa(sslCa); return this; } /** * Configures client SSL certificate for client authentication. *

- * The {@link #sslCert} and {@link #sslKey} must be both non-{@code null} or both {@code null}. + * It and {@link #sslKey} must be both non-{@code null} or both {@code null}. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param sslCert an X.509 certificate chain file in PEM format, or {@code null} if no SSL cert. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.2 */ public Builder sslCert(@Nullable String sslCert) { - this.sslCert = sslCert; + this.ssl.sslCert(sslCert); return this; } /** * Configures client SSL key for client authentication. *

- * The {@link #sslCert} and {@link #sslKey} must be both non-{@code null} or both {@code null}. + * It and {@link #sslCert} must be both non-{@code null} or both {@code null}. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param sslKey a PKCS#8 private key file in PEM format, or {@code null} if no SSL key. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.2 */ public Builder sslKey(@Nullable String sslKey) { - this.sslKey = sslKey; + this.ssl.sslKey(sslKey); return this; } @@ -786,39 +856,42 @@ public Builder sslKey(@Nullable String sslKey) { * Configures the password of SSL key file for client certificate authentication. *

* It will be used only if {@link #sslKey} and {@link #sslCert} non-null. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * - * @param sslKeyPassword the password of the {@link #sslKey}, or {@code null} if it's not - * password-protected. - * @return this {@link Builder}. + * @param sslKeyPassword the password of the {@link #sslKey}, or {@code null} if it's not password-protected. + * @return {@link Builder this} * @since 0.8.2 */ public Builder sslKeyPassword(@Nullable CharSequence sslKeyPassword) { - this.sslKeyPassword = sslKeyPassword; + this.ssl.sslKeyPassword(sslKeyPassword); return this; } /** - * Configures a {@link SslContextBuilder} customizer. The customizer gets applied on each SSL - * connection attempt to allow for just-in-time configuration updates. The {@link Function} gets - * called with the prepared {@link SslContextBuilder} that has all configuration options applied. The - * customizer may return the same builder or return a new builder instance to be used to build the SSL - * context. + * Configures a {@link SslContextBuilder} customizer. The customizer gets applied on each SSL connection attempt + * to allow for just-in-time configuration updates. The {@link Function} gets called with the prepared + * {@link SslContextBuilder} that has all configuration options applied. The customizer may return the same + * builder or return a new builder instance to be used to build the SSL context. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param customizer customizer function * @return this {@link Builder} * @throws IllegalArgumentException if {@code customizer} is {@code null} * @since 0.8.1 */ - public Builder sslContextBuilderCustomizer( - Function customizer) { + public Builder sslContextBuilderCustomizer(Function customizer) { requireNonNull(customizer, "sslContextBuilderCustomizer must not be null"); - this.sslContextBuilderCustomizer = customizer; + this.ssl.sslContextBuilderCustomizer(customizer); return this; } /** * Configures TCP KeepAlive. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param enabled whether to enable TCP KeepAlive * @return this {@link Builder} @@ -826,12 +899,14 @@ public Builder sslContextBuilderCustomizer( * @since 0.8.2 */ public Builder tcpKeepAlive(boolean enabled) { - this.tcpKeepAlive = enabled; + requireTcpSocket().tcpKeepAlive(enabled); return this; } /** * Configures TCP NoDelay. + *

+ * Note: It is used only if the {@link #unixSocket(String)} is not configured. * * @param enabled whether to enable TCP NoDelay * @return this {@link Builder} @@ -839,15 +914,14 @@ public Builder tcpKeepAlive(boolean enabled) { * @since 0.8.2 */ public Builder tcpNoDelay(boolean enabled) { - this.tcpNoDelay = enabled; + requireTcpSocket().tcpNoDelay(enabled); return this; } /** * Configures the protocol of parametrized statements to the text protocol. *

- * The text protocol is default protocol that's using client-preparing. See also MySQL - * documentations. + * The text protocol is default protocol that's using client-preparing. See also MySQL documentations. * * @return this {@link Builder} * @since 0.8.1 @@ -860,10 +934,9 @@ public Builder useClientPrepareStatement() { /** * Configures the protocol of parametrized statements to the binary protocol. *

- * The binary protocol is compact protocol that's using server-preparing. See also MySQL - * documentations. + * The binary protocol is compact protocol that's using server-preparing. See also MySQL documentations. * - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.1 */ public Builder useServerPrepareStatement() { @@ -871,19 +944,18 @@ public Builder useServerPrepareStatement() { } /** - * Configures the protocol of parametrized statements and prepare-preferred simple statements to the - * binary protocol. + * Configures the protocol of parametrized statements and prepare-preferred simple statements to the binary + * protocol. *

- * The {@code preferPrepareStatement} configures whether to prefer prepare execution on a - * statement-by-statement basis (simple statements). The {@link Predicate} accepts the simple SQL - * query string and returns a boolean flag indicating preference. {@code true} prepare-preferred, - * {@code false} prefers direct execution (text protocol). Defaults to direct execution. + * The {@code preferPrepareStatement} configures whether to prefer prepare execution on a statement-by-statement + * basis (simple statements). The {@link Predicate} accepts the simple SQL query string and returns a boolean + * flag indicating preference. {@code true} prepare-preferred, {@code false} prefers direct execution (text + * protocol). Defaults to direct execution. *

- * The binary protocol is compact protocol that's using server-preparing. See also MySQL - * documentations. + * The binary protocol is compact protocol that's using server-preparing. See also MySQL documentations. * * @param preferPrepareStatement the above {@link Predicate}. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code preferPrepareStatement} is {@code null}. * @since 0.8.1 */ @@ -895,8 +967,8 @@ public Builder useServerPrepareStatement(Predicate preferPrepareStatemen } /** - * Configures the session variables, used to set session variables immediately after login. Default no - * session variables to set. It should be a list of key-value pairs. e.g. + * Configures the session variables, used to set session variables immediately after login. Default no session + * variables to set. It should be a list of key-value pairs. e.g. * {@code ["sql_mode='ANSI_QUOTES,STRICT_TRANS_TABLES'", "time_zone=00:00"]}. * * @param sessionVariables the session variables to set. @@ -912,8 +984,8 @@ public Builder sessionVariables(String... sessionVariables) { } /** - * Configures to allow the {@code LOAD DATA LOCAL INFILE} statement in the given {@code path} or - * disallow the statement. Default to {@code null} which means not allow the statement. + * Configures to allow the {@code LOAD DATA LOCAL INFILE} statement in the given {@code path} or disallow the + * statement. Default to {@code null} which means not allow the statement. * * @param path which parent path are allowed to load file data, {@code null} means not be allowed. * @return {@link Builder this}. @@ -944,14 +1016,14 @@ public Builder localInfileBufferSize(int localInfileBufferSize) { } /** - * Configures the maximum size of the {@link Query} parsing cache. Usually it should be power of two. - * Default to {@code 0}. Driver will use unbounded cache if size is less than {@code 0}. + * Configures the maximum size of the {@link Query} parsing cache. Usually it should be power of two. Default to + * {@code 0}. Driver will use unbounded cache if size is less than {@code 0}. *

- * Notice: the cache is using EL model (the PACELC theorem) which provider better performance. That - * means it is an elastic cache. So this size is not a hard-limit. It should be over 16 in average. + * Notice: the cache is using EL model (the PACELC theorem) which provider better performance. That means it is + * an elastic cache. So this size is not a hard-limit. It should be over 16 in average. * * @param queryCacheSize the above size, {@code 0} means no cache, {@code -1} means unbounded cache. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.3 */ public Builder queryCacheSize(int queryCacheSize) { @@ -960,19 +1032,17 @@ public Builder queryCacheSize(int queryCacheSize) { } /** - * Configures the maximum size of the server-preparing cache. Usually it should be power of two. - * Default to {@code 256}. Driver will use unbounded cache if size is less than {@code 0}. It is used - * only if using server-preparing parametrized statements, i.e. the {@link #useServerPrepareStatement} - * is set. + * Configures the maximum size of the server-preparing cache. Usually it should be power of two. Default to + * {@code 256}. Driver will use unbounded cache if size is less than {@code 0}. It is used only if using + * server-preparing parametrized statements, i.e. the {@link #useServerPrepareStatement} is set. *

- * Notice: the cache is using EC model (the PACELC theorem) for ensure consistency. Consistency is - * very important because MySQL contains a hard limit of all server-prepared statements which has been - * opened, see also {@code max_prepared_stmt_count}. And, the cache is one-to-one connection, which - * means it will not work on thread-concurrency. + * Notice: the cache is using EC model (the PACELC theorem) for ensure consistency. Consistency is very + * important because MySQL contains a hard limit of all server-prepared statements which has been opened, see + * also {@code max_prepared_stmt_count}. And, the cache is one-to-one connection, which means it will not work + * on thread-concurrency. * - * @param prepareCacheSize the above size, {@code 0} means no cache, {@code -1} means unbounded - * cache. - * @return this {@link Builder}. + * @param prepareCacheSize the above size, {@code 0} means no cache, {@code -1} means unbounded cache. + * @return {@link Builder this} * @since 0.8.3 */ public Builder prepareCacheSize(int prepareCacheSize) { @@ -983,10 +1053,9 @@ public Builder prepareCacheSize(int prepareCacheSize) { /** * Configures the compression algorithms. Default to [{@link CompressionAlgorithm#UNCOMPRESSED}]. *

- * It will auto choose an algorithm that's contained in the list and supported by the server, - * preferring zstd, then zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} - * and the server does not support any algorithm in the list, an exception will be thrown when - * connecting. + * It will auto choose an algorithm that's contained in the list and supported by the server, preferring zstd, + * then zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} and the server does not + * support any algorithm in the list, an exception will be thrown when connecting. *

* Note: zstd requires a dependency {@code com.github.luben:zstd-jni}. * @@ -1043,12 +1112,12 @@ public Builder zstdCompressionLevel(int level) { * {@link TcpResources#get() global tcp resources}. * * @param loopResources the {@link LoopResources}. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code loopResources} is {@code null}. * @since 1.1.2 */ public Builder loopResources(LoopResources loopResources) { - this.loopResources = requireNonNull(loopResources, "loopResources must not be null"); + this.client.loopResources(loopResources); return this; } @@ -1057,7 +1126,7 @@ public Builder loopResources(LoopResources loopResources) { * {@code true}. * * @param enabled to discover and register extensions. - * @return this {@link Builder}. + * @return {@link Builder this} * @since 0.8.2 */ public Builder autodetectExtensions(boolean enabled) { @@ -1068,12 +1137,12 @@ public Builder autodetectExtensions(boolean enabled) { /** * Registers a {@link Extension} to extend driver functionality and manually. *

- * Notice: the driver will not deduplicate {@link Extension}s of autodetect discovered and manually - * extended. So if a {@link Extension} is registered by this function and autodetect discovered, it - * will get two {@link Extension} as same. + * Notice: the driver will not deduplicate {@link Extension}s of autodetect discovered and manually extended. So + * if a {@link Extension} is registered by this function and autodetect discovered, it will get two + * {@link Extension} as same. * * @param extension extension to extend driver functionality. - * @return this {@link Builder}. + * @return {@link Builder this} * @throws IllegalArgumentException if {@code extension} is {@code null}. * @since 0.8.2 */ @@ -1083,26 +1152,36 @@ public Builder extendWith(Extension extension) { } /** - * Registers a password publisher function. + * Registers a password publisher function. Since 1.2.0, it is an alias of {@link #password(Publisher)}. * - * @param passwordPublisher function to retrieve password before making connection. - * @return this {@link Builder}. + * @param password a {@link Publisher} to retrieve password before making connection. + * @return {@link Builder this} */ - public Builder passwordPublisher(Publisher passwordPublisher) { - this.passwordPublisher = passwordPublisher; - return this; + public Builder passwordPublisher(Publisher password) { + return password(password); } - private SslMode requireSslMode() { - SslMode sslMode = this.sslMode; + private TcpSocketConfiguration.Builder requireTcpSocket() { + TcpSocketConfiguration.Builder tcpSocket = this.tcpSocket; - if (sslMode == null) { - sslMode = isHost ? SslMode.PREFERRED : SslMode.DISABLED; + if (tcpSocket == null) { + this.tcpSocket = tcpSocket = new TcpSocketConfiguration.Builder(); } - return sslMode; + return tcpSocket; } - private Builder() { } + private UnixDomainSocketConfiguration.Builder requireUnixSocket() { + UnixDomainSocketConfiguration.Builder unixSocket = this.unixSocket; + + if (unixSocket == null) { + this.unixSocket = unixSocket = new UnixDomainSocketConfiguration.Builder(); + } + + return unixSocket; + } + + private Builder() { + } } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java index 9e269eda5..0881b877d 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java @@ -20,29 +20,20 @@ import io.asyncer.r2dbc.mysql.cache.Caches; import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.cache.QueryCache; -import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; import io.asyncer.r2dbc.mysql.codec.CodecsBuilder; -import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; -import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.extension.CodecRegistrar; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.unix.DomainSocketAddress; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryMetadata; import org.jetbrains.annotations.Nullable; -import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.time.ZoneId; import java.time.ZoneOffset; import java.util.ArrayList; import java.util.List; -import java.util.Objects; -import java.util.Set; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Predicate; @@ -77,105 +68,38 @@ public ConnectionFactoryMetadata getMetadata() { */ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configuration) { requireNonNull(configuration, "configuration must not be null"); - LazyQueryCache queryCache = new LazyQueryCache(configuration.getQueryCacheSize()); - return new MySqlConnectionFactory(Mono.defer(() -> { - MySqlSslConfiguration ssl; - SocketAddress address; - - if (configuration.isHost()) { - ssl = configuration.getSsl(); - address = InetSocketAddress.createUnresolved(configuration.getDomain(), - configuration.getPort()); - } else { - ssl = MySqlSslConfiguration.disabled(); - address = new DomainSocketAddress(configuration.getDomain()); - } - - String database = configuration.getDatabase(); - boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist(); - String user = configuration.getUser(); - CharSequence password = configuration.getPassword(); - SslMode sslMode = ssl.getSslMode(); - int zstdCompressionLevel = configuration.getZstdCompressionLevel(); - ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone()); - ConnectionContext context = new ConnectionContext( - configuration.getZeroDateOption(), - configuration.getLoadLocalInfilePath(), - configuration.getLocalInfileBufferSize(), - configuration.isPreserveInstants(), - connectionTimeZone - ); - Set compressionAlgorithms = configuration.getCompressionAlgorithms(); - Extensions extensions = configuration.getExtensions(); - Predicate prepare = configuration.getPreferPrepareStatement(); - int prepareCacheSize = configuration.getPrepareCacheSize(); - Publisher passwordPublisher = configuration.getPasswordPublisher(); - boolean forceTimeZone = configuration.isForceConnectionTimeZoneToSession(); - List sessionVariables = forceTimeZone && connectionTimeZone != null ? - mergeSessionVariables(configuration.getSessionVariables(), connectionTimeZone) : - configuration.getSessionVariables(); - - if (Objects.nonNull(passwordPublisher)) { - return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection( - configuration, queryCache, - ssl, address, - database, createDbIfNotExist, - user, sslMode, - compressionAlgorithms, zstdCompressionLevel, - context, extensions, sessionVariables, prepare, - prepareCacheSize, token - )); - } - - return getMySqlConnection( - configuration, queryCache, - ssl, address, - database, createDbIfNotExist, - user, sslMode, - compressionAlgorithms, zstdCompressionLevel, - context, extensions, sessionVariables, prepare, - prepareCacheSize, password - ); - })); + return new MySqlConnectionFactory(Mono.defer(() -> connectWithInit(configuration, queryCache))); } - private static Mono getMySqlConnection( - final MySqlConnectionConfiguration configuration, - final LazyQueryCache queryCache, - final MySqlSslConfiguration ssl, - final SocketAddress address, - final String database, - final boolean createDbIfNotExist, - final String user, - final SslMode sslMode, - final Set compressionAlgorithms, - final int zstdCompressionLevel, - final ConnectionContext context, - final Extensions extensions, - final List sessionVariables, - @Nullable final Predicate prepare, - final int prepareCacheSize, - @Nullable final CharSequence password) { - return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(), - context, configuration.getConnectTimeout(), configuration.getLoopResources()) - .flatMap(client -> { - // Lazy init database after handshake/login - String db = createDbIfNotExist ? "" : database; - return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms, - zstdCompressionLevel, context); - }) + private static Mono connectWithInit( + MySqlConnectionConfiguration configuration, + LazyQueryCache queryCache + ) { + return configuration.getSocket() + .strategy(configuration) + .connect() .flatMap(client -> { + String database = configuration.getDatabase(); + boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist(); + ZoneId connectionTimeZone = retrieveZoneId(configuration.getConnectionTimeZone()); + Predicate prepare = configuration.getPreferPrepareStatement(); + int prepareCacheSize = configuration.getPrepareCacheSize(); + boolean forceTimeZone = configuration.isForceConnectionTimeZoneToSession(); + List sessionVariables = forceTimeZone && connectionTimeZone != null ? + mergeSessionVariables(configuration.getSessionVariables(), connectionTimeZone) : + configuration.getSessionVariables(); ByteBufAllocator allocator = client.getByteBufAllocator(); CodecsBuilder builder = Codecs.builder(); PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize); String db = createDbIfNotExist ? database : ""; + Extensions extensions = configuration.getExtensions(); extensions.forEach(CodecRegistrar.class, registrar -> registrar.register(allocator, builder)); - return MySqlSimpleConnection.init(client, builder.build(), context, db, queryCache.get(), + return MySqlSimpleConnection.init(client, builder.build(), db, queryCache.get(), prepareCache, sessionVariables, prepare); }); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java index 652bfd5fe..b14a97646 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java @@ -17,8 +17,12 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.ProtocolDriver; +import io.asyncer.r2dbc.mysql.constant.HaProtocol; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; +import io.asyncer.r2dbc.mysql.internal.util.AddressUtils; import io.netty.handler.ssl.SslContextBuilder; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryOptions; @@ -44,6 +48,7 @@ import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; import static io.r2dbc.spi.ConnectionFactoryOptions.PORT; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; @@ -52,11 +57,6 @@ */ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryProvider { - /** - * The name of the driver used for discovery, should not be changed. - */ - public static final String MYSQL_DRIVER = "mysql"; - /** * Option to set the Unix Domain Socket. * @@ -65,8 +65,8 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option UNIX_SOCKET = Option.valueOf("unixSocket"); /** - * Option to set the time zone conversion. Default to {@code true} means enable conversion between JVM - * and {@link #CONNECTION_TIME_ZONE}. + * Option to set the time zone conversion. Default to {@code true} means enable conversion between JVM and + * {@link #CONNECTION_TIME_ZONE}. *

* Note: disable it will ignore the time zone of connection, and use the JVM local time zone. * @@ -75,9 +75,9 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option PRESERVE_INSTANTS = Option.valueOf("preserveInstants"); /** - * Option to set the time zone of connection. Default to {@code LOCAL} means use JVM local time zone. - * It should be {@code "LOCAL"}, {@code "SERVER"}, or a valid ID of {@code ZoneId}. {@code "SERVER"} means - * querying the server-side timezone during initialization. + * Option to set the time zone of connection. Default to {@code LOCAL} means use JVM local time zone. It should be + * {@code "LOCAL"}, {@code "SERVER"}, or a valid ID of {@code ZoneId}. {@code "SERVER"} means querying the + * server-side timezone during initialization. * * @since 1.1.2 */ @@ -86,8 +86,8 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr /** * Option to force the time zone of connection to session time zone. Default to {@code false}. *

- * Note: alter the time zone of session will affect the results of MySQL date/time functions, e.g. - * {@code NOW([n])}, {@code CURRENT_TIME([n])}, {@code CURRENT_DATE()}, etc. Please use with caution. + * Note: alter the time zone of session will affect the results of MySQL date/time functions, e.g. {@code NOW([n])}, + * {@code CURRENT_TIME([n])}, {@code CURRENT_DATE()}, etc. Please use with caution. * * @since 1.1.2 */ @@ -95,8 +95,7 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr Option.valueOf("forceConnectionTimeZoneToSession"); /** - * Option to set {@link ZoneId} of server. If it is set, driver will ignore the real time zone of - * server-side. + * Option to set {@link ZoneId} of server. If it is set, driver will ignore the real time zone of server-side. * * @since 0.8.2 * @deprecated since 1.1.2, use {@link #CONNECTION_TIME_ZONE} instead. @@ -120,8 +119,8 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr /** * Option to configure {@link HostnameVerifier}. It is available only if the {@link #SSL_MODE} set to - * {@link SslMode#VERIFY_IDENTITY}. It can be an implementation class name of {@link HostnameVerifier} - * with a public no-args constructor. + * {@link SslMode#VERIFY_IDENTITY}. It can be an implementation class name of {@link HostnameVerifier} with a public + * no-args constructor. * * @since 0.8.2 */ @@ -129,17 +128,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr Option.valueOf("sslHostnameVerifier"); /** - * Option to TLS versions for SslContext protocols, see also {@code TlsVersions}. Usually sorted from - * higher to lower. It can be a {@code Collection}. It can be a {@link String}, protocols will be - * split by {@code ,}. e.g. "TLSv1.2,TLSv1.1,TLSv1". + * Option to TLS versions for SslContext protocols, see also {@code TlsVersions}. Usually sorted from higher to + * lower. It can be a {@code Collection}. It can be a {@link String}, protocols will be split by {@code ,}. + * e.g. "TLSv1.2,TLSv1.1,TLSv1". * * @since 0.8.1 */ public static final Option TLS_VERSION = Option.valueOf("tlsVersion"); /** - * Option to set a PEM file of server SSL CA. It will be used to verify server certificates. And it will - * be used only if {@link #SSL_MODE} set to {@link SslMode#VERIFY_CA} or higher level. + * Option to set a PEM file of server SSL CA. It will be used to verify server certificates. And it will be used + * only if {@link #SSL_MODE} set to {@link SslMode#VERIFY_CA} or higher level. * * @since 0.8.1 */ @@ -168,8 +167,8 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option SSL_CERT = Option.valueOf("sslCert"); /** - * Option to custom {@link SslContextBuilder}. It can be an implementation class name of {@link Function} - * with a public no-args constructor. + * Option to custom {@link SslContextBuilder}. It can be an implementation class name of {@link Function} with a + * public no-args constructor. * * @since 0.8.2 */ @@ -201,18 +200,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr /** * Enable server preparing for parametrized statements and prefer server preparing simple statements. *

- * The value can be a {@link Boolean}. If it is {@code true}, driver will use server preparing for - * parametrized statements and text query for simple statements. If it is {@code false}, driver will use - * client preparing for parametrized statements and text query for simple statements. + * The value can be a {@link Boolean}. If it is {@code true}, driver will use server preparing for parametrized + * statements and text query for simple statements. If it is {@code false}, driver will use client preparing for + * parametrized statements and text query for simple statements. *

- * The value can be a {@link Predicate}{@code <}{@link String}{@code >}. If it is set, driver will server - * preparing for parametrized statements, it configures whether to prefer prepare execution on a - * statement-by-statement basis (simple statements). The {@link Predicate}{@code <}{@link String}{@code >} - * accepts the simple SQL query string and returns a {@code boolean} flag indicating preference. + * The value can be a {@link Predicate}{@code <}{@link String}{@code >}. If it is set, driver will server preparing + * for parametrized statements, it configures whether to prefer prepare execution on a statement-by-statement basis + * (simple statements). The {@link Predicate}{@code <}{@link String}{@code >} accepts the simple SQL query string + * and returns a {@code boolean} flag indicating preference. *

- * The value can be a {@link String}. If it is set, driver will try to convert it to {@link Boolean} or an - * instance of {@link Predicate}{@code <}{@link String}{@code >} which use reflection with a public - * no-args constructor. + * The value can be a {@link String}. If it is set, driver will try to convert it to {@link Boolean} or an instance + * of {@link Predicate}{@code <}{@link String}{@code >} which use reflection with a public no-args constructor. * * @since 0.8.1 */ @@ -246,9 +244,9 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr /** * Option to set compression algorithms. Default to [{@link CompressionAlgorithm#UNCOMPRESSED}]. *

- * It will auto choose an algorithm that's contained in the list and supported by the server, preferring - * zstd, then zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} and the server - * does not support any algorithm in the list, an exception will be thrown when connecting. + * It will auto choose an algorithm that's contained in the list and supported by the server, preferring zstd, then + * zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} and the server does not support any + * algorithm in the list, an exception will be thrown when connecting. *

* Note: zstd requires a dependency {@code com.github.luben:zstd-jni}. * @@ -262,8 +260,7 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr *

* It is only used if zstd is chosen for the connection. *

- * Note: MySQL protocol does not allow to set the zlib compression level of the server, only zstd is - * configurable. + * Note: MySQL protocol does not allow to set the zlib compression level of the server, only zstd is configurable. * * @since 1.1.2 */ @@ -300,9 +297,9 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option AUTODETECT_EXTENSIONS = Option.valueOf("autodetectExtensions"); /** - * Password Publisher function can be used to retrieve password before creating a connection. This can be - * used with Amazon RDS Aurora IAM authentication, wherein it requires token to be generated. The token is - * valid for 15 minutes, and this token will be used as password. + * Password Publisher function can be used to retrieve password before creating a connection. This can be used with + * Amazon RDS Aurora IAM authentication, wherein it requires token to be generated. The token is valid for 15 + * minutes, and this token will be used as password. */ public static final Option> PASSWORD_PUBLISHER = Option.valueOf("passwordPublisher"); @@ -316,12 +313,14 @@ public ConnectionFactory create(ConnectionFactoryOptions options) { @Override public boolean supports(ConnectionFactoryOptions options) { requireNonNull(options, "connectionFactoryOptions must not be null"); - return MYSQL_DRIVER.equals(options.getValue(DRIVER)); + + Object driver = options.getValue(DRIVER); + return driver instanceof String && ProtocolDriver.supports((String) driver); } @Override public String getDriver() { - return MYSQL_DRIVER; + return ProtocolDriver.standardDriver(); } /** @@ -338,16 +337,26 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) { .to(builder::user); mapper.optional(PASSWORD).asPassword() .to(builder::password); - mapper.optional(UNIX_SOCKET).asString() - .to(builder::unixSocket) - .otherwise(() -> setupHost(builder, mapper)); + + boolean unixSocket = mapper.optional(UNIX_SOCKET).asString() + .to(builder::unixSocket); + + if (!unixSocket) { + setupHost(builder, mapper); + } + mapper.optional(PRESERVE_INSTANTS).asBoolean() .to(builder::preserveInstants); - mapper.optional(CONNECTION_TIME_ZONE).asString() - .to(builder::connectionTimeZone) - .otherwise(() -> mapper.optional(SERVER_ZONE_ID) + + boolean connectionTimeZone = mapper.optional(CONNECTION_TIME_ZONE).asString() + .to(builder::connectionTimeZone); + + if (!connectionTimeZone) { + mapper.optional(SERVER_ZONE_ID) .as(ZoneId.class, id -> ZoneId.of(id, ZoneId.SHORT_IDS)) - .to(builder::serverZoneId)); + .to(builder::serverZoneId); + } + mapper.optional(FORCE_CONNECTION_TIME_ZONE_TO_SESSION).asBoolean() .to(builder::forceConnectionTimeZoneToSession); mapper.optional(TCP_KEEP_ALIVE).asBoolean() @@ -398,17 +407,44 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) { } /** - * Set builder of {@link MySqlConnectionConfiguration} for hostname-based address with SSL - * configurations. + * Set builder of {@link MySqlConnectionConfiguration} for hostname-based path with SSL configurations. * * @param builder the builder of {@link MySqlConnectionConfiguration}. * @param mapper the {@link OptionMapper} of {@code options}. */ private static void setupHost(MySqlConnectionConfiguration.Builder builder, OptionMapper mapper) { - mapper.requires(HOST).asString() - .to(builder::host); - mapper.optional(PORT).asInt() + boolean port = mapper.optional(PORT).asInt() .to(builder::port); + + if (port) { + // If port is set, host must be a single host. + mapper.requires(HOST).asString() + .to(builder::host); + } else { + // If port is not set, host can be a single host or multiple hosts. + // If the URI contains an underscore in the host, it will produce an incorrectly resolved host and port. + // e.g. "r2dbc:mysql://my_db:3306" will be resolved to "my_db:3306" as host and null as port. + // See https://github.com/asyncer-io/r2dbc-mysql/issues/255 + mapper.requires(HOST) + .asArray(String[].class, Function.identity(), it -> it.split(","), String[]::new) + .to(hosts -> { + if (hosts.length == 1) { + builder.host(hosts[0]); + return; + } + + for (String host : hosts) { + NodeAddress address = AddressUtils.parseAddress(host); + + builder.addHost(address.getHost(), address.getPort()); + } + }); + } + + mapper.requires(DRIVER).as(ProtocolDriver.class, ProtocolDriver::from) + .to(builder::driver); + mapper.optional(PROTOCOL).as(HaProtocol.class, HaProtocol::from) + .to(builder::protocol); mapper.optional(SSL).asBoolean() .to(isSsl -> builder.sslMode(isSsl ? SslMode.REQUIRED : SslMode.DISABLED)); mapper.optional(SSL_MODE).as(SslMode.class, id -> SslMode.valueOf(id.toUpperCase())) @@ -431,12 +467,12 @@ private static void setupHost(MySqlConnectionConfiguration.Builder builder, Opti } /** - * Splits session variables from user input. e.g. {@code sql_mode='ANSI_QUOTE,STRICT',c=d;e=f} will be - * split into {@code ["sql_mode='ANSI_QUOTE,STRICT'", "c=d", "e=f"]}. + * Splits session variables from user input. e.g. {@code sql_mode='ANSI_QUOTE,STRICT',c=d;e=f} will be split into + * {@code ["sql_mode='ANSI_QUOTE,STRICT'", "c=d", "e=f"]}. *

- * It supports escaping characters with backslash, quoted values with single or double quotes, and nested - * brackets. Priorities are: backslash in quoted > single quote = double quote > bracket, backslash - * will not be a valid escape character if it is not in a quoted value. + * It supports escaping characters with backslash, quoted values with single or double quotes, and nested brackets. + * Priorities are: backslash in quoted > single quote = double quote > bracket, backslash will not be a valid + * escape character if it is not in a quoted value. *

* Note that it does not strictly check syntax validity, so it will not throw syntax exceptions. * diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java index f4a2b3746..53e659056 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnection.java @@ -139,8 +139,6 @@ final class MySqlSimpleConnection implements MySqlConnection, ConnectionState { private final boolean batchSupported; - private final ConnectionContext context; - private final MySqlConnectionMetadata metadata; private volatile IsolationLevel sessionLevel; @@ -174,11 +172,12 @@ final class MySqlSimpleConnection implements MySqlConnection, ConnectionState { */ private volatile long currentLockWaitTimeout; - MySqlSimpleConnection(Client client, ConnectionContext context, Codecs codecs, IsolationLevel level, + MySqlSimpleConnection(Client client, Codecs codecs, IsolationLevel level, long lockWaitTimeout, QueryCache queryCache, PrepareCache prepareCache, @Nullable String product, @Nullable Predicate prepare) { + ConnectionContext context = client.getContext(); + this.client = client; - this.context = context; this.sessionLevel = level; this.currentLevel = level; this.codecs = codecs; @@ -227,9 +226,7 @@ public Mono commitTransaction() { @Override public MySqlBatch createBatch() { - return batchSupported ? new MySqlBatchingBatch(client, codecs, context) : - new MySqlSyntheticBatch(client, codecs, context); - + return batchSupported ? new MySqlBatchingBatch(client, codecs) : new MySqlSyntheticBatch(client, codecs); } @Override @@ -244,7 +241,7 @@ public MySqlStatement createStatement(String sql) { requireNonNull(sql, "sql must not be null"); if (sql.startsWith(PING_MARKER)) { - return new PingStatement(codecs, context, Flux.defer(this::doPingInternal)); + return new PingStatement(codecs, client.getContext(), Flux.defer(this::doPingInternal)); } Query query = queryCache.get(sql); @@ -252,22 +249,22 @@ public MySqlStatement createStatement(String sql) { if (query.isSimple()) { if (prepare != null && prepare.test(sql)) { logger.debug("Create a simple statement provided by prepare query"); - return new PrepareSimpleStatement(client, codecs, context, sql, prepareCache); + return new PrepareSimpleStatement(client, codecs, sql, prepareCache); } logger.debug("Create a simple statement provided by text query"); - return new TextSimpleStatement(client, codecs, context, sql); + return new TextSimpleStatement(client, codecs, sql); } if (prepare == null) { logger.debug("Create a parametrized statement provided by text query"); - return new TextParametrizedStatement(client, codecs, query, context); + return new TextParametrizedStatement(client, codecs, query); } logger.debug("Create a parametrized statement provided by prepare query"); - return new PrepareParametrizedStatement(client, codecs, query, context, prepareCache); + return new PrepareParametrizedStatement(client, codecs, query, prepareCache); } @Override @@ -417,7 +414,7 @@ public void resetCurrentLockWaitTimeout() { @Override public boolean isInTransaction() { - return (context.getServerStatuses() & ServerStatuses.IN_TRANSACTION) != 0; + return (client.getContext().getServerStatuses() & ServerStatuses.IN_TRANSACTION) != 0; } @Override @@ -432,6 +429,8 @@ public Mono setLockWaitTimeout(Duration timeout) { @Override public Mono setStatementTimeout(Duration timeout) { requireNonNull(timeout, "timeout must not be null"); + + final ConnectionContext context = client.getContext(); final boolean isMariaDb = context.isMariaDb(); final ServerVersion serverVersion = context.getServerVersion(); final long timeoutMs = timeout.toMillis(); @@ -461,7 +460,7 @@ private Flux doPingInternal() { } private boolean isSessionAutoCommit() { - return (context.getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0; + return (client.getContext().getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0; } /** @@ -469,7 +468,6 @@ private boolean isSessionAutoCommit() { * * @param client must be logged-in. * @param codecs the {@link Codecs}. - * @param context must be initialized. * @param database the database that should be lazy init. * @param queryCache the cache of {@link Query}. * @param prepareCache the cache of server-preparing result. @@ -478,20 +476,20 @@ private boolean isSessionAutoCommit() { * @return a {@link Mono} will emit an initialized {@link MySqlConnection}. */ static Mono init( - Client client, Codecs codecs, ConnectionContext context, String database, + Client client, Codecs codecs, String database, QueryCache queryCache, PrepareCache prepareCache, List sessionVariables, @Nullable Predicate prepare ) { Mono connection = initSessionVariables(client, sessionVariables) - .then(loadSessionVariables(client, codecs, context)) + .then(loadSessionVariables(client, codecs)) .map(data -> { ZoneId timeZone = data.timeZone; if (timeZone != null) { logger.debug("Got server time zone {} from loading session variables", timeZone); - context.setTimeZone(timeZone); + client.getContext().setTimeZone(timeZone); } - return new MySqlSimpleConnection(client, context, codecs, data.level, data.lockWaitTimeout, + return new MySqlSimpleConnection(client, codecs, data.level, data.lockWaitTimeout, queryCache, prepareCache, data.product, prepare); }); @@ -531,9 +529,8 @@ private static Mono initSessionVariables(Client client, List sessi return QueryFlow.executeVoid(client, query.toString()); } - private static Mono loadSessionVariables( - Client client, Codecs codecs, ConnectionContext context - ) { + private static Mono loadSessionVariables(Client client, Codecs codecs) { + ConnectionContext context = client.getContext(); StringBuilder query = new StringBuilder(160) .append("SELECT ") .append(transactionIsolationColumn(context)) @@ -548,7 +545,7 @@ private static Mono loadSessionVariables( handler = r -> convertSessionData(r, true); } - return new TextSimpleStatement(client, codecs, context, query.toString()) + return new TextSimpleStatement(client, codecs, query.toString()) .execute() .flatMap(handler) .last(); diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSslConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSslConfiguration.java index d76662f40..00dd10cfb 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSslConfiguration.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSslConfiguration.java @@ -17,6 +17,7 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.constant.SslMode; +import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; import io.netty.handler.ssl.SslContextBuilder; import org.jetbrains.annotations.Nullable; @@ -25,7 +26,7 @@ import java.util.Objects; import java.util.function.Function; -import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; import static io.asyncer.r2dbc.mysql.internal.util.InternalArrays.EMPTY_STRINGS; /** @@ -106,8 +107,8 @@ public String getSslCert() { } /** - * Customizes a {@link SslContextBuilder} that customizer was specified by configuration, or do nothing if - * the customizer was not set. + * Customizes a {@link SslContextBuilder} that customizer was specified by configuration, or do nothing if the + * customizer was not set. * * @param builder the {@link SslContextBuilder}. * @return the {@code builder}. @@ -162,19 +163,87 @@ static MySqlSslConfiguration disabled() { return DISABLED; } - static MySqlSslConfiguration create(SslMode sslMode, String[] tlsVersion, - @Nullable HostnameVerifier sslHostnameVerifier, @Nullable String sslCa, @Nullable String sslKey, - @Nullable CharSequence sslKeyPassword, @Nullable String sslCert, - @Nullable Function sslContextBuilderCustomizer) { - requireNonNull(sslMode, "sslMode must not be null"); + static final class Builder { - if (sslMode == SslMode.DISABLED) { - return DISABLED; + @Nullable + private SslMode sslMode; + + private String[] tlsVersions = InternalArrays.EMPTY_STRINGS; + + @Nullable + private HostnameVerifier sslHostnameVerifier; + + @Nullable + private String sslCa; + + @Nullable + private String sslKey; + + @Nullable + private CharSequence sslKeyPassword; + + @Nullable + private String sslCert; + + @Nullable + private Function sslContextBuilderCustomizer; + + void sslMode(SslMode sslMode) { + this.sslMode = sslMode; } - requireNonNull(tlsVersion, "tlsVersion must not be null"); + void tlsVersions(String[] tlsVersions) { + int size = tlsVersions.length; + + if (size > 0) { + String[] versions = new String[size]; + System.arraycopy(tlsVersions, 0, versions, 0, size); + this.tlsVersions = versions; + } else { + this.tlsVersions = EMPTY_STRINGS; + } + } + + void sslHostnameVerifier(HostnameVerifier sslHostnameVerifier) { + this.sslHostnameVerifier = sslHostnameVerifier; + } - return new MySqlSslConfiguration(sslMode, tlsVersion, sslHostnameVerifier, sslCa, sslKey, - sslKeyPassword, sslCert, sslContextBuilderCustomizer); + void sslCa(@Nullable String sslCa) { + this.sslCa = sslCa; + } + + void sslCert(@Nullable String sslCert) { + this.sslCert = sslCert; + } + + void sslKey(@Nullable String sslKey) { + this.sslKey = sslKey; + } + + void sslKeyPassword(@Nullable CharSequence sslKeyPassword) { + this.sslKeyPassword = sslKeyPassword; + } + + void sslContextBuilderCustomizer(Function customizer) { + this.sslContextBuilderCustomizer = customizer; + } + + MySqlSslConfiguration build(boolean preferred) { + SslMode sslMode = this.sslMode; + + if (sslMode == null) { + sslMode = preferred ? SslMode.PREFERRED : SslMode.DISABLED; + } + + if (sslMode == SslMode.DISABLED) { + return DISABLED; + } + + require((sslCert == null && sslKey == null) || (sslCert != null && sslKey != null), + "sslCert and sslKey must be both null or both non-null"); + + return new MySqlSslConfiguration(sslMode, tlsVersions, sslHostnameVerifier, sslCa, sslKey, + sslKeyPassword, sslCert, sslContextBuilderCustomizer); + } } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java index d976b6155..1357350a7 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java @@ -17,6 +17,7 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.api.MySqlStatement; +import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; import org.jetbrains.annotations.Nullable; @@ -32,13 +33,13 @@ abstract class MySqlStatementSupport implements MySqlStatement { private static final String LAST_INSERT_ID = "LAST_INSERT_ID"; - protected final ConnectionContext context; + protected final Client client; @Nullable private String[] generatedColumns = null; - MySqlStatementSupport(ConnectionContext context) { - this.context = requireNonNull(context, "context must not be null"); + MySqlStatementSupport(Client client) { + this.client = requireNonNull(client, "client must not be null"); } @Override @@ -49,7 +50,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) { if (len == 0) { this.generatedColumns = InternalArrays.EMPTY_STRINGS; - } else if (len == 1 || supportReturning(context)) { + } else if (len == 1 || supportReturning(client.getContext())) { String[] result = new String[len]; for (int i = 0; i < len; ++i) { @@ -59,7 +60,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) { this.generatedColumns = result; } else { - String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL"; + String db = client.getContext().isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL"; throw new IllegalArgumentException(db + " can have only one column"); } @@ -71,7 +72,7 @@ final String syntheticKeyName() { String[] columns = this.generatedColumns; // MariaDB should use `RETURNING` clause instead. - if (columns == null || supportReturning(this.context)) { + if (columns == null || supportReturning(client.getContext())) { return null; } @@ -85,7 +86,7 @@ final String syntheticKeyName() { final String returningIdentifiers() { String[] columns = this.generatedColumns; - if (columns == null || !supportReturning(context)) { + if (columns == null || !supportReturning(client.getContext())) { return ""; } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatch.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatch.java index efc677beb..b3b4b3bab 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatch.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatch.java @@ -37,14 +37,11 @@ final class MySqlSyntheticBatch implements MySqlBatch { private final Codecs codecs; - private final ConnectionContext context; - private final List statements = new ArrayList<>(); - MySqlSyntheticBatch(Client client, Codecs codecs, ConnectionContext context) { + MySqlSyntheticBatch(Client client, Codecs codecs) { this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); - this.context = requireNonNull(context, "context must not be null"); } @Override @@ -56,7 +53,7 @@ public MySqlBatch add(String sql) { @Override public Flux execute() { return QueryFlow.execute(client, statements) - .map(messages -> MySqlSegmentResult.toResult(false, codecs, context, null, messages)); + .map(messages -> MySqlSegmentResult.toResult(false, codecs, client.getContext(), null, messages)); } @Override diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java index f75a913f1..afc67a8bb 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java @@ -60,14 +60,13 @@ private Source(@Nullable T value) { this.value = value; } - Otherwise to(Consumer consumer) { + boolean to(Consumer consumer) { if (value == null) { - return Otherwise.FALL; + return false; } consumer.accept(value); - - return Otherwise.NOOP; + return true; } Source as(Class type) { @@ -268,27 +267,3 @@ private static O[] mapArray(String[] input, Function mapper, IntF return output; } } - -enum Otherwise { - - NOOP { - @Override - void otherwise(Runnable runnable) { - // Do nothing - } - }, - - FALL { - @Override - void otherwise(Runnable runnable) { - runnable.run(); - } - }; - - /** - * Invoked if the previous {@link Source} outcome did not match. - * - * @param runnable the {@link Runnable} that should be invoked. - */ - abstract void otherwise(Runnable runnable); -} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java index 41ea8e465..bb37b31ec 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java @@ -41,8 +41,6 @@ */ abstract class ParametrizedStatementSupport extends MySqlStatementSupport { - protected final Client client; - protected final Codecs codecs; protected final Query query; @@ -51,13 +49,12 @@ abstract class ParametrizedStatementSupport extends MySqlStatementSupport { private final AtomicBoolean executed = new AtomicBoolean(); - ParametrizedStatementSupport(Client client, Codecs codecs, Query query, ConnectionContext context) { - super(context); + ParametrizedStatementSupport(Client client, Codecs codecs, Query query) { + super(client); requireNonNull(query, "query must not be null"); require(query.getParameters() > 0, "parameters must be a positive integer"); - this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); this.query = query; this.bindings = new Bindings(query.getParameters()); @@ -75,7 +72,7 @@ public final MySqlStatement add() { public final MySqlStatement bind(int index, Object value) { requireNonNull(value, "value must not be null"); - addBinding(index, codecs.encode(value, context)); + addBinding(index, codecs.encode(value, client.getContext())); return this; } @@ -84,7 +81,7 @@ public final MySqlStatement bind(String name, Object value) { requireNonNull(name, "name must not be null"); requireNonNull(value, "value must not be null"); - addBinding(getIndexes(name), codecs.encode(value, context)); + addBinding(getIndexes(name), codecs.encode(value, client.getContext())); return this; } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java index 9395a1309..b0a7ea8cc 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java @@ -37,9 +37,8 @@ final class PrepareParametrizedStatement extends ParametrizedStatementSupport { private int fetchSize = 0; - PrepareParametrizedStatement(Client client, Codecs codecs, Query query, ConnectionContext context, - PrepareCache prepareCache) { - super(client, codecs, query, context); + PrepareParametrizedStatement(Client client, Codecs codecs, Query query, PrepareCache prepareCache) { + super(client, codecs, query); this.prepareCache = prepareCache; } @@ -49,7 +48,8 @@ public Flux execute(List bindings) { StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()), bindings, fetchSize, prepareCache )) - .map(messages -> MySqlSegmentResult.toResult(true, codecs, context, syntheticKeyName(), messages)); + .map(messages -> MySqlSegmentResult.toResult( + true, codecs, client.getContext(), syntheticKeyName(), messages)); } @Override diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java index d037eda39..755faff02 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java @@ -40,9 +40,9 @@ final class PrepareSimpleStatement extends SimpleStatementSupport { private int fetchSize = 0; - PrepareSimpleStatement(Client client, Codecs codecs, ConnectionContext context, String sql, + PrepareSimpleStatement(Client client, Codecs codecs, String sql, PrepareCache prepareCache) { - super(client, codecs, context, sql); + super(client, codecs, sql); this.prepareCache = prepareCache; } @@ -50,7 +50,8 @@ final class PrepareSimpleStatement extends SimpleStatementSupport { public Flux execute() { return Flux.defer(() -> QueryFlow.execute(client, StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache)) - .map(messages -> MySqlSegmentResult.toResult(true, codecs, context, syntheticKeyName(), messages)); + .map(messages -> MySqlSegmentResult.toResult( + true, codecs, client.getContext(), syntheticKeyName(), messages)); } @Override diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java index 7b100cd24..3a6b9ee32 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java @@ -107,10 +107,10 @@ final class QueryFlow { }; /** - * Execute multiple bindings of a server-preparing statement with one-by-one binary execution. The - * execution terminates with the last {@link CompleteMessage} or a {@link ErrorMessage}. If client - * receives a {@link ErrorMessage} will cancel subsequent {@link Binding}s. The exchange will be completed - * by {@link CompleteMessage} after receive the last result for the last binding. + * Execute multiple bindings of a server-preparing statement with one-by-one binary execution. The execution + * terminates with the last {@link CompleteMessage} or a {@link ErrorMessage}. If client receives a + * {@link ErrorMessage} will cancel subsequent {@link Binding}s. The exchange will be completed by + * {@link CompleteMessage} after receive the last result for the last binding. * * @param client the {@link Client} to exchange messages with. * @param sql the statement for exception tracing. @@ -133,10 +133,10 @@ static Flux> execute(Client client, String sql, List> execute( /** * Execute a simple compound query. Query execution terminates with the last {@link CompleteMessage} or a - * {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed - * by {@link CompleteMessage} after receive the last result for the last binding. + * {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed by + * {@link CompleteMessage} after receive the last result for the last binding. * * @param client the {@link Client} to exchange messages with. * @param sql the query to execute, can be contains multi-statements. @@ -172,9 +172,9 @@ static Flux> execute(Client client, String sql) { /** * Execute multiple simple compound queries with one-by-one. Query execution terminates with the last - * {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception and - * cancel subsequent statements' execution. The exchange will be completed by {@link CompleteMessage} - * after receive the last result for the last binding. + * {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception and cancel + * subsequent statements' execution. The exchange will be completed by {@link CompleteMessage} after receive the + * last result for the last binding. * * @param client the {@link Client} to exchange messages with. * @param statements bundled sql for execute. @@ -195,34 +195,37 @@ static Flux> execute(Client client, List statements) } /** - * Login a {@link Client} and receive the {@code client} after logon. It will emit an exception when - * client receives a {@link ErrorMessage}. + * Login a {@link Client} and receive the {@code client} after logon. It will emit an exception when client receives + * a {@link ErrorMessage}. * * @param client the {@link Client} to exchange messages with. * @param sslMode the {@link SslMode} defines SSL capability and behavior. * @param database the database that will be connected. - * @param user the user that will be login. - * @param password the password of the {@code user}. + * @param credential the {@link Credential} for login. * @param compressionAlgorithms the list of compression algorithms. * @param zstdCompressionLevel the zstd compression level. * @param context the {@link ConnectionContext} for initialization. * @return the messages received in response to the login exchange. */ - static Mono login(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, + static Mono login(Client client, SslMode sslMode, String database, Credential credential, Set compressionAlgorithms, int zstdCompressionLevel, ConnectionContext context) { - return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, - compressionAlgorithms, zstdCompressionLevel, context)) - .onErrorResume(e -> client.forceClose().then(Mono.error(e))) - .then(Mono.just(client)); + return client.exchange(new LoginExchangeable( + client, + sslMode, + database, + credential.getUser(), + credential.getPassword(), + compressionAlgorithms, + zstdCompressionLevel, + context + )).onErrorResume(e -> client.forceClose().then(Mono.error(e))).then(Mono.just(client)); } /** - * Execute a simple query and return a {@link Mono} for the complete signal or error. Query execution - * terminates with the last {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} - * will emit an exception. The exchange will be completed by {@link CompleteMessage} after receive the - * last result for the last binding. + * Execute a simple query and return a {@link Mono} for the complete signal or error. Query execution terminates + * with the last {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. + * The exchange will be completed by {@link CompleteMessage} after receive the last result for the last binding. *

* Note: this method does not support {@code LOCAL INFILE} due to it should be used for excepted queries. * @@ -246,8 +249,8 @@ static Mono executeVoid(Client client, String sql) { } /** - * Begins a new transaction with a {@link TransactionDefinition}. It will change current transaction - * statuses of the {@link ConnectionState}. + * Begins a new transaction with a {@link TransactionDefinition}. It will change current transaction statuses of + * the {@link ConnectionState}. * * @param client the {@link Client} to exchange messages with. * @param state the connection state for checks and sets transaction statuses. @@ -267,8 +270,8 @@ static Mono beginTransaction(Client client, ConnectionState state, boolean } /** - * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in - * the initial connection state. + * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in the initial + * connection state. * * @param client the {@link Client} to exchange messages with. * @param state the connection state for checks and resets transaction statuses. @@ -298,9 +301,9 @@ static Mono createSavepoint(Client client, ConnectionState state, String n /** * Execute a simple query statement. Query execution terminates with the last {@link CompleteMessage} or a - * {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed - * by {@link CompleteMessage} after receive the last result for the last binding. The exchange will be - * completed by {@link CompleteMessage} after receive the last result for the last binding. + * {@link ErrorMessage}. The {@link ErrorMessage} will emit an exception. The exchange will be completed by + * {@link CompleteMessage} after receive the last result for the last binding. The exchange will be completed by + * {@link CompleteMessage} after receive the last result for the last binding. * * @param client the {@link Client} to exchange messages with. * @param sql the query to execute, can be contains multi-statements. @@ -310,7 +313,8 @@ private static Flux execute0(Client client, String sql) { return client.exchange(new SimpleQueryExchangeable(sql)); } - private QueryFlow() { } + private QueryFlow() { + } } /** @@ -523,12 +527,12 @@ protected String offendingSql() { } /** - * An implementation of {@link FluxExchangeable} that considers server-preparing queries. Which contains a - * built-in state machine. + * An implementation of {@link FluxExchangeable} that considers server-preparing queries. Which contains a built-in + * state machine. *

- * It will reset a prepared statement if cache has matched it, otherwise it will prepare statement to a new - * statement ID and put the ID into the cache. If the statement ID does not exist in the cache after the last - * row sent, the ID will be closed. + * It will reset a prepared statement if cache has matched it, otherwise it will prepare statement to a new statement ID + * and put the ID into the cache. If the statement ID does not exist in the cache after the last row sent, the ID will + * be closed. */ final class PrepareExchangeable extends FluxExchangeable { @@ -813,8 +817,8 @@ private void onCompleteMessage(CompleteMessage message, SynchronousSink - * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an - * implementation of {@link HandshakeRequest}. + * Not like other {@link FluxExchangeable}s, it is started by a server-side message, which should be an implementation + * of {@link HandshakeRequest}. */ final class LoginExchangeable extends FluxExchangeable { diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java index 42ba279e3..78a6ec781 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java @@ -27,16 +27,13 @@ */ abstract class SimpleStatementSupport extends MySqlStatementSupport { - protected final Client client; - protected final Codecs codecs; protected final String sql; - SimpleStatementSupport(Client client, Codecs codecs, ConnectionContext context, String sql) { - super(context); + SimpleStatementSupport(Client client, Codecs codecs, String sql) { + super(client); - this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); this.sql = requireNonNull(sql, "sql must not be null"); } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SingleHostConnectionStrategy.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SingleHostConnectionStrategy.java new file mode 100644 index 000000000..59f83e457 --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SingleHostConnectionStrategy.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.client.Client; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; +import io.netty.channel.ChannelOption; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ConnectionStrategy} that connects to a single host. It can be wrapped to a + * FailoverClient. + */ +final class SingleHostConnectionStrategy implements ConnectionStrategy { + + private final Mono client; + + SingleHostConnectionStrategy(TcpSocketConfiguration socket, MySqlConnectionConfiguration configuration) { + this.client = configuration.getCredential().flatMap(credential -> { + NodeAddress address = socket.getFirstAddress(); + + logger.debug("Connect to a single host: {}", address); + + TcpClient tcpClient = ConnectionStrategy.createTcpClient(configuration.getClient(), true) + .option(ChannelOption.SO_KEEPALIVE, socket.isTcpKeepAlive()) + .option(ChannelOption.TCP_NODELAY, socket.isTcpNoDelay()) + .remoteAddress(address::toUnresolved); + + return ConnectionStrategy.login(tcpClient, credential, configuration); + }); + } + + @Override + public Mono connect() { + return client; + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketClientConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketClientConfiguration.java new file mode 100644 index 000000000..3102e345b --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketClientConfiguration.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import org.jetbrains.annotations.Nullable; +import reactor.netty.resources.LoopResources; + +import java.time.Duration; +import java.util.Objects; + +/** + * A general-purpose configuration for a socket client. The client can be a TCP client or a Unix Domain Socket client. + */ +final class SocketClientConfiguration { + + @Nullable + private final Duration connectTimeout; + + @Nullable + private final LoopResources loopResources; + + SocketClientConfiguration(@Nullable Duration connectTimeout, @Nullable LoopResources loopResources) { + this.connectTimeout = connectTimeout; + this.loopResources = loopResources; + } + + @Nullable + Duration getConnectTimeout() { + return connectTimeout; + } + + @Nullable + LoopResources getLoopResources() { + return loopResources; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SocketClientConfiguration)) { + return false; + } + + SocketClientConfiguration that = (SocketClientConfiguration) o; + + return Objects.equals(connectTimeout, that.connectTimeout) && Objects.equals(loopResources, that.loopResources); + } + + @Override + public int hashCode() { + return 31 * Objects.hashCode(connectTimeout) + Objects.hashCode(loopResources); + } + + @Override + public String toString() { + return "Client{connectTimeout=" + connectTimeout + ", loopResources=" + loopResources + '}'; + } + + static final class Builder { + + @Nullable + private Duration connectTimeout; + + @Nullable + private LoopResources loopResources; + + void connectTimeout(@Nullable Duration connectTimeout) { + this.connectTimeout = connectTimeout; + } + + void loopResources(@Nullable LoopResources loopResources) { + this.loopResources = loopResources; + } + + SocketClientConfiguration build() { + return new SocketClientConfiguration(connectTimeout, loopResources); + } + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketConfiguration.java new file mode 100644 index 000000000..de317ddde --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/SocketConfiguration.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +/** + * A sealed interface for socket configuration, it is also a factory for creating {@link ConnectionStrategy}. + * + * @see TcpSocketConfiguration + * @see UnixDomainSocketConfiguration + */ +interface SocketConfiguration { + + ConnectionStrategy strategy(MySqlConnectionConfiguration configuration); +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TcpSocketConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TcpSocketConfiguration.java new file mode 100644 index 000000000..f100a687f --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TcpSocketConfiguration.java @@ -0,0 +1,235 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.constant.HaProtocol; +import io.asyncer.r2dbc.mysql.constant.ProtocolDriver; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; +import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; + +/** + * A configuration for a TCP/SSL socket. + */ +final class TcpSocketConfiguration implements SocketConfiguration { + + private static final int DEFAULT_PORT = 3306; + + private final ProtocolDriver driver; + + private final HaProtocol protocol; + + private final List addresses; + + private final int retriesAllDown; + + private final boolean tcpKeepAlive; + + private final boolean tcpNoDelay; + + TcpSocketConfiguration( + ProtocolDriver driver, + HaProtocol protocol, + List addresses, + int retriesAllDown, + boolean tcpKeepAlive, + boolean tcpNoDelay + ) { + this.driver = driver; + this.protocol = protocol; + this.addresses = addresses; + this.retriesAllDown = retriesAllDown; + this.tcpKeepAlive = tcpKeepAlive; + this.tcpNoDelay = tcpNoDelay; + } + + ProtocolDriver getDriver() { + return driver; + } + + HaProtocol getProtocol() { + return protocol; + } + + NodeAddress getFirstAddress() { + if (addresses.isEmpty()) { + throw new IllegalStateException("No endpoints configured"); + } + return addresses.get(0); + } + + List getAddresses() { + return addresses; + } + + int getRetriesAllDown() { + return retriesAllDown; + } + + boolean isTcpKeepAlive() { + return tcpKeepAlive; + } + + boolean isTcpNoDelay() { + return tcpNoDelay; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TcpSocketConfiguration)) { + return false; + } + + TcpSocketConfiguration that = (TcpSocketConfiguration) o; + + return tcpKeepAlive == that.tcpKeepAlive && + tcpNoDelay == that.tcpNoDelay && + driver == that.driver && + protocol == that.protocol && + retriesAllDown == that.retriesAllDown && + addresses.equals(that.addresses); + } + + @Override + public int hashCode() { + int result = driver.hashCode(); + + result = 31 * result + protocol.hashCode(); + result = 31 * result + addresses.hashCode(); + result = 31 * result + retriesAllDown; + result = 31 * result + (tcpKeepAlive ? 1 : 0); + + return 31 * result + (tcpNoDelay ? 1 : 0); + } + + @Override + public String toString() { + return "TCP{driver=" + driver + + ", protocol=" + protocol + + ", addresses=" + addresses + + ", retriesAllDown=" + retriesAllDown + + ", tcpKeepAlive=" + tcpKeepAlive + + ", tcpNoDelay=" + tcpNoDelay + + '}'; + } + + static final class Builder { + + private ProtocolDriver driver = ProtocolDriver.MYSQL; + + private HaProtocol protocol = HaProtocol.DEFAULT; + + private final List addresses = new ArrayList<>(); + + private String host = ""; + + private int port = DEFAULT_PORT; + + private boolean tcpKeepAlive = false; + + private boolean tcpNoDelay = true; + + private int retriesAllDown = 10; + + void driver(ProtocolDriver driver) { + this.driver = driver; + } + + void protocol(HaProtocol protocol) { + this.protocol = protocol; + } + + void host(String host) { + this.host = host; + } + + void port(int port) { + this.port = port; + } + + void addHost(String host, int port) { + this.addresses.add(new NodeAddress(host, port)); + } + + void addHost(String host) { + this.addresses.add(new NodeAddress(host)); + } + + void retriesAllDown(int retriesAllDown) { + this.retriesAllDown = retriesAllDown; + } + + void tcpKeepAlive(boolean tcpKeepAlive) { + this.tcpKeepAlive = tcpKeepAlive; + } + + void tcpNoDelay(boolean tcpNoDelay) { + this.tcpNoDelay = tcpNoDelay; + } + + TcpSocketConfiguration build() { + List addresses; + + if (this.addresses.isEmpty()) { + requireNonEmpty(host, "Either single host or multiple hosts must be configured"); + + addresses = Collections.singletonList(new NodeAddress(host, port)); + } else { + require(host.isEmpty(), "Either single host or multiple hosts must be configured"); + + addresses = InternalArrays.asImmutableList(this.addresses.toArray(new NodeAddress[0])); + } + + return new TcpSocketConfiguration( + driver, + protocol, + addresses, + retriesAllDown, + tcpKeepAlive, + tcpNoDelay); + } + } + + @Override + public ConnectionStrategy strategy(MySqlConnectionConfiguration configuration) { + switch (protocol) { + case REPLICATION: + ConnectionStrategy.logger.warn( + "R2DBC Connection cannot be set to read-only, replication protocol will use the first host"); + return new SingleHostConnectionStrategy(this, configuration); + case SEQUENTIAL: + return new MultiHostsConnectionStrategy(this, configuration, false); + case LOAD_BALANCE: + return new MultiHostsConnectionStrategy(this, configuration, true); + default: + if (ProtocolDriver.MYSQL == driver && addresses.size() == 1) { + return new SingleHostConnectionStrategy(this, configuration); + } else { + return new MultiHostsConnectionStrategy(this, configuration, false); + } + } + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java index 88a10d1a1..64e7df6e0 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java @@ -28,14 +28,14 @@ */ final class TextParametrizedStatement extends ParametrizedStatementSupport { - TextParametrizedStatement(Client client, Codecs codecs, Query query, ConnectionContext context) { - super(client, codecs, query, context); + TextParametrizedStatement(Client client, Codecs codecs, Query query) { + super(client, codecs, query); } @Override protected Flux execute(List bindings) { - return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(), - bindings)) - .map(messages -> MySqlSegmentResult.toResult(false, codecs, context, syntheticKeyName(), messages)); + return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(), bindings)) + .map(messages -> MySqlSegmentResult.toResult( + false, codecs, client.getContext(), syntheticKeyName(), messages)); } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java index a265f7af2..878d555eb 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java @@ -27,8 +27,8 @@ */ final class TextSimpleStatement extends SimpleStatementSupport { - TextSimpleStatement(Client client, Codecs codecs, ConnectionContext context, String sql) { - super(client, codecs, context, sql); + TextSimpleStatement(Client client, Codecs codecs, String sql) { + super(client, codecs, sql); } @Override @@ -36,6 +36,7 @@ public Flux execute() { return Flux.defer(() -> QueryFlow.execute( client, StringUtils.extendReturning(sql, returningIdentifiers()) - ).map(messages -> MySqlSegmentResult.toResult(false, codecs, context, syntheticKeyName(), messages))); + ).map(messages -> MySqlSegmentResult.toResult( + false, codecs, client.getContext(), syntheticKeyName(), messages))); } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConfiguration.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConfiguration.java new file mode 100644 index 000000000..7d71f1d85 --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConfiguration.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +/** + * A configuration for a Unix Domain Socket. + */ +final class UnixDomainSocketConfiguration implements SocketConfiguration { + + private final String path; + + UnixDomainSocketConfiguration(String path) { + this.path = path; + } + + String getPath() { + return this.path; + } + + @Override + public ConnectionStrategy strategy(MySqlConnectionConfiguration configuration) { + return new UnixDomainSocketConnectionStrategy(this, configuration); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof UnixDomainSocketConfiguration)) { + return false; + } + + UnixDomainSocketConfiguration that = (UnixDomainSocketConfiguration) o; + + return path.equals(that.path); + } + + @Override + public int hashCode() { + return path.hashCode(); + } + + @Override + public String toString() { + return "UnixDomainSocket{path='" + path + "'}"; + } + + static final class Builder { + + private String path; + + void path(String path) { + this.path = path; + } + + UnixDomainSocketConfiguration build() { + return new UnixDomainSocketConfiguration(path); + } + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConnectionStrategy.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConnectionStrategy.java new file mode 100644 index 000000000..60ef8f58d --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/UnixDomainSocketConnectionStrategy.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.client.Client; +import io.netty.channel.unix.DomainSocketAddress; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ConnectionStrategy} that connects to a Unix Domain Socket. + */ +final class UnixDomainSocketConnectionStrategy implements ConnectionStrategy { + + private final Mono client; + + UnixDomainSocketConnectionStrategy( + UnixDomainSocketConfiguration socket, + MySqlConnectionConfiguration configuration + ) { + this.client = configuration.getCredential().flatMap(credential -> { + String path = socket.getPath(); + TcpClient tcpClient = ConnectionStrategy.createTcpClient(configuration.getClient(), false) + .remoteAddress(() -> new DomainSocketAddress(path)); + + return ConnectionStrategy.login(tcpClient, credential, configuration); + }); + } + + @Override + public Mono connect() { + return client; + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java index bf2b8a219..68611d41f 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/Client.java @@ -21,19 +21,13 @@ import io.asyncer.r2dbc.mysql.message.client.ClientMessage; import io.asyncer.r2dbc.mysql.message.server.ServerMessage; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.ChannelOption; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import org.jetbrains.annotations.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.SynchronousSink; -import reactor.netty.resources.LoopResources; import reactor.netty.tcp.TcpClient; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.time.Duration; import java.util.function.BiConsumer; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; @@ -46,21 +40,21 @@ public interface Client { InternalLogger logger = InternalLoggerFactory.getInstance(Client.class); /** - * Perform an exchange of a request message. Calling this method while a previous exchange is active will - * return a deferred handle and queue the request until the previous exchange terminates. + * Perform an exchange of a request message. Calling this method while a previous exchange is active will return a + * deferred handle and queue the request until the previous exchange terminates. * * @param request one and only one request message for get server responses - * @param handler response handler, {@link SynchronousSink#complete()} should be called after the last - * response frame is sent to complete the stream and prevent multiple subscribers from - * consuming previous, active response streams + * @param handler response handler, {@link SynchronousSink#complete()} should be called after the last response + * frame is sent to complete the stream and prevent multiple subscribers from consuming previous, + * active response streams * @param handling response type * @return A {@link Flux} of incoming messages that ends with the end of the frame */ Flux exchange(ClientMessage request, BiConsumer> handler); /** - * Perform an exchange of multi-request messages. Calling this method while a previous exchange is active - * will return a deferred handle and queue the request until the previous exchange terminates. + * Perform an exchange of multi-request messages. Calling this method while a previous exchange is active will + * return a deferred handle and queue the request until the previous exchange terminates. * * @param exchangeable request messages and response handler * @param handling response type @@ -91,6 +85,13 @@ public interface Client { */ ByteBufAllocator getByteBufAllocator(); + /** + * Returns the {@link ConnectionContext}. + * + * @return the {@link ConnectionContext} + */ + ConnectionContext getContext(); + /** * Checks if the connection is open. * @@ -109,40 +110,19 @@ public interface Client { void loginSuccess(); /** - * Connects to {@code address} with configurations. Normally, should log-in after connected. + * Connects to a MySQL server using the provided {@link TcpClient} and {@link MySqlSslConfiguration}. * - * @param ssl the SSL configuration - * @param address socket address, may be host address, or Unix Domain Socket address - * @param tcpKeepAlive if enable the {@link ChannelOption#SO_KEEPALIVE} - * @param tcpNoDelay if enable the {@link ChannelOption#TCP_NODELAY} - * @param context the connection context - * @param connectTimeout connect timeout, or {@code null} if it has no timeout - * @param loopResources the loop resources to use + * @param tcpClient the configured TCP client + * @param ssl the SSL configuration + * @param context the connection context * @return A {@link Mono} that will emit a connected {@link Client}. - * @throws IllegalArgumentException if {@code ssl}, {@code address} or {@code context} is {@code null}. - * @throws ArithmeticException if {@code connectTimeout} milliseconds overflow as an int + * @throws IllegalArgumentException if {@code tcpClient}, {@code ssl} or {@code context} is {@code null}. */ - static Mono connect(MySqlSslConfiguration ssl, SocketAddress address, boolean tcpKeepAlive, - boolean tcpNoDelay, ConnectionContext context, @Nullable Duration connectTimeout, - LoopResources loopResources) { + static Mono connect(TcpClient tcpClient, MySqlSslConfiguration ssl, ConnectionContext context) { + requireNonNull(tcpClient, "tcpClient must not be null"); requireNonNull(ssl, "ssl must not be null"); - requireNonNull(address, "address must not be null"); requireNonNull(context, "context must not be null"); - TcpClient tcpClient = TcpClient.newConnection() - .runOn(loopResources); - - if (connectTimeout != null) { - tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, - Math.toIntExact(connectTimeout.toMillis())); - } - - if (address instanceof InetSocketAddress) { - tcpClient = tcpClient.option(ChannelOption.SO_KEEPALIVE, tcpKeepAlive); - tcpClient = tcpClient.option(ChannelOption.TCP_NODELAY, tcpNoDelay); - } - - return tcpClient.remoteAddress(() -> address).connect() - .map(conn -> new ReactorNettyClient(conn, ssl, context)); + return tcpClient.connect().map(conn -> new ReactorNettyClient(conn, ssl, context)); } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java index b9a12f3cc..81cb5f21e 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java @@ -240,6 +240,11 @@ public ByteBufAllocator getByteBufAllocator() { return connection.outbound().alloc(); } + @Override + public ConnectionContext getContext() { + return context; + } + @Override public boolean isConnected() { return state < ST_CLOSED && connection.channel().isOpen(); diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/HaProtocol.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/HaProtocol.java new file mode 100644 index 000000000..c54cd8923 --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/HaProtocol.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.constant; + +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; + +/** + * Failover and High-availability protocol. + *

+ * The reconnect behavior is affected by the {@code autoReconnect} option. + */ +public enum HaProtocol { + + /** + * Connecting: I want to connect sequentially until the first available node is found if multiple nodes are + * provided, otherwise connect to the single node. + *

+ * Using: I want to get back to the first node if either {@code secondsBeforeRetryPrimaryHost} or + * {@code queriesBeforeRetryPrimaryHost} is set, and multiple nodes are provided. + *

+ * Reconnect: I want to reconnect in the same order if the current node is not available and + * {@code autoReconnect=true}. + */ + DEFAULT(""), + + /** + * Connecting: I want to connect sequentially until the first available node is found. + *

+ * Using: I want to keep using the current node until it is not available. + *

+ * Reconnect: I want to reconnect in the same order if the current node is not available and + * {@code autoReconnect=true}. + */ + SEQUENTIAL("sequential"), + + /** + * Connecting: I want to connect in random order until the first available node is found. + *

+ * Using: I want to keep using the current node until it is not available. + *

+ * Reconnect: I want to re-randomize the order to reconnect if the current node is not available and + * {@code autoReconnect=true}. + */ + LOAD_BALANCE("loadbalance"), + + /** + * Connecting: I want to use read-write connection for the first node, and read-only connections for other nodes. + *

+ * Using: I want to use the first node for read-write if connection is set to read-write, and other nodes if + * connection is set to read-only. R2DBC can not set a {@link io.r2dbc.spi.Connection Connection} to read-only mode. + * So it will always use the first host. R2DBC does not recommend this mutability. Perhaps in the future, R2DBC will + * support using read-only mode to create a connection instead of modifying an existing connection. + *

+ * Reconnect: I want to reconnect to the current node if the current node is unavailable and + * {@code autoReconnect=true}. + * + * @see Proposal: add Connection.setReadonly(boolean) + */ + REPLICATION("replication"), + ; + + private final String name; + + HaProtocol(String name) { + this.name = name; + } + + public static HaProtocol from(String protocol) { + requireNonNull(protocol, "HA protocol must not be null"); + + for (HaProtocol haProtocol : HaProtocol.values()) { + if (haProtocol.name.equalsIgnoreCase(protocol)) { + return haProtocol; + } + } + + throw new IllegalArgumentException("Unknown HA protocol: " + protocol); + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/ProtocolDriver.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/ProtocolDriver.java new file mode 100644 index 000000000..9c228cecb --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/constant/ProtocolDriver.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.constant; + +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; + +/** + * Enumeration of driver schemes. + */ +public enum ProtocolDriver { + + /** + * I want to use failover and high availability protocols for each host I set up. If I set a hostname that resolves + * to multiple IP addresses, the driver should pick one randomly. + *

+ * Recommended in most cases. The hostname is resolved when high availability protocols are applied. + */ + MYSQL, + + /** + * I want to use failover and high availability protocols for each IP address. If I set a hostname that resolves to + * multiple IP addresses, the driver should flatten the list and try to connect to all of IP addresses. + *

+ * The hostname is resolved before high availability protocols are applied. + */ + DNS_SRV; + + /** + * Default protocol driver name. + */ + private static final String STANDARD_NAME = "mysql"; + + /** + * DNS SRV protocol driver name. + */ + private static final String DNS_SRV_NAME = "mysql+srv"; + + public static String standardDriver() { + return STANDARD_NAME; + } + + public static boolean supports(String driverName) { + requireNonNull(driverName, "driverName must not be null"); + + switch (driverName) { + case STANDARD_NAME: + case DNS_SRV_NAME: + return true; + default: + return false; + } + } + + public static ProtocolDriver from(String driverName) { + requireNonNull(driverName, "driverName must not be null"); + + switch (driverName) { + case STANDARD_NAME: + return MYSQL; + case DNS_SRV_NAME: + return DNS_SRV; + default: + throw new IllegalArgumentException("Unknown driver name: " + driverName); + } + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/NodeAddress.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/NodeAddress.java new file mode 100644 index 000000000..cb1afccd8 --- /dev/null +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/NodeAddress.java @@ -0,0 +1,76 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.internal; + +import java.net.InetSocketAddress; + +/** + * A value object representing a host and port. It will use the default port {@code 3306} if not specified. + */ +public final class NodeAddress { + + private static final int DEFAULT_PORT = 3306; + + private final String host; + + private final int port; + + public NodeAddress(String host) { + this(host, DEFAULT_PORT); + } + + public NodeAddress(String host, int port) { + this.host = host; + this.port = port; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public InetSocketAddress toUnresolved() { + return InetSocketAddress.createUnresolved(this.host, this.port); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NodeAddress)) { + return false; + } + + NodeAddress that = (NodeAddress) o; + + return port == that.port && host.equals(that.host); + } + + @Override + public int hashCode() { + return 31 * host.hashCode() + port; + } + + @Override + public String toString() { + return host + ":" + port; + } +} diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtils.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtils.java index 82e41d522..422faf725 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtils.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtils.java @@ -16,10 +16,13 @@ package io.asyncer.r2dbc.mysql.internal.util; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; + +import java.net.InetSocketAddress; import java.util.regex.Pattern; /** - * A utility for matching host/address. + * A utility for processing host/address. */ public final class AddressUtils { @@ -31,32 +34,104 @@ public final class AddressUtils { private static final Pattern IPV6_PATTERN = Pattern.compile("^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}$"); private static final Pattern IPV6_COMPRESSED_PATTERN = Pattern.compile( - "^(([0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){0,5})?)::(([0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){0,5})?)$"); + "^((([0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){0,5})?)::(([0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){0,5})?))$"); private static final int IPV6_COLONS = 7; /** * Checks if the host is an address of IP version 4. * - * @param host the host should be check. + * @param host the host should be checked. * @return if is IPv4. */ public static boolean isIpv4(String host) { - // TODO: Use faster matches instead of regex. + // Maybe use faster matches instead of regex? return IPV4_PATTERN.matcher(host).matches(); } /** * Checks if the host is an address of IP version 6. * - * @param host the host should be check. + * @param host the host should be checked. * @return if is IPv6. */ public static boolean isIpv6(String host) { - // TODO: Use faster matches instead of regex. + // Maybe use faster matches instead of regex? return IPV6_PATTERN.matcher(host).matches() || isIpv6Compressed(host); } + /** + * Parses a host to an {@link NodeAddress}, the {@code host} may contain port or not. If the {@code host} does + * not contain a valid port, the default port {@code 3306} will be used. The {@code host} can be an IPv6, IPv4 or + * host address. e.g. [::1]:3301, [::1], 127.0.0.1, host-name:3302 + *

+ * Note: It will not check if the host is a valid address. e.g. IPv6 address should be enclosed in square brackets, + * hostname should not contain an underscore, etc. + * + * @param host the {@code host} should be parsed as socket address. + * @return the parsed and unresolved {@link InetSocketAddress} + */ + public static NodeAddress parseAddress(String host) { + int len = host.length(); + int index; + + for (index = len - 1; index > 0; --index) { + char ch = host.charAt(index); + + if (ch == ':') { + break; + } else if (ch < '0' || ch > '9') { + return new NodeAddress(host); + } + } + + if (index == 0) { + // index == 0, no host before number whatever host[0] is a colon or not, may be a hostname "a1234" + return new NodeAddress(host); + } + + int colonLen = len - index; + + if (colonLen < 2 || colonLen > 6) { + // 1. no port after colon, not a port, may be an IPv6 address like "::" + // 2. length of port > 5, max port is 65535, invalid port + return new NodeAddress(host); + } + + if (host.charAt(index - 1) == ']' && host.charAt(0) == '[') { + // Seems like an IPv6 with port + if (index <= 2) { + // Host/Address must not be empty + return new NodeAddress(host); + } + + int port = parsePort(host, index + 1, len); + + if (port > 0xFFFF) { + return new NodeAddress(host); + } + + return new NodeAddress(host.substring(0, index), port); + } + + int colonIndex = index; + + // IPv4 or host should not contain a colon, IPv6 should be enclosed in square brackets + for (--index; index >= 0; --index) { + if (host.charAt(index) == ':') { + return new NodeAddress(host); + } + } + + int port = parsePort(host, colonIndex + 1, len); + + if (port > 0xFFFF) { + return new NodeAddress(host); + } + + return new NodeAddress(host.substring(0, colonIndex), port); + } + private static boolean isIpv6Compressed(String host) { int length = host.length(); int colons = 0; @@ -67,9 +142,20 @@ private static boolean isIpv6Compressed(String host) { } } - // TODO: Use faster matches instead of regex. + // Maybe use faster matches instead of regex? return colons <= IPV6_COLONS && IPV6_COMPRESSED_PATTERN.matcher(host).matches(); } - private AddressUtils() { } + private static int parsePort(String input, int start, int end) { + int r = 0; + + for (int i = start; i < end; ++i) { + r = r * 10 + (input.charAt(i) - '0'); + } + + return r; + } + + private AddressUtils() { + } } diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/InternalArrays.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/InternalArrays.java index 7b73186d0..d78009116 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/InternalArrays.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/internal/util/InternalArrays.java @@ -252,7 +252,7 @@ public T[] toArray(T[] a) { return (T[]) Arrays.copyOf(source, source.length, a.getClass()); } - System.arraycopy(source, 0, a, 0, this.a.length); + System.arraycopy(source, 0, a, 0, source.length); if (a.length > source.length) { a[source.length] = null; diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/HaProtocolIntegrationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/HaProtocolIntegrationTest.java new file mode 100644 index 000000000..b2ead6cb5 --- /dev/null +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/HaProtocolIntegrationTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.constant.HaProtocol; +import io.r2dbc.spi.ValidationDepth; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link HaProtocol}. + */ +class HaProtocolIntegrationTest { + + @ParameterizedTest + @ValueSource(strings = { "sequential", "loadbalance" }) + void anyAvailable(String protocol) { + MySqlConnectionFactory.from(configuration(HaProtocol.from(protocol), true)).create() + .flatMapMany(connection -> connection.validate(ValidationDepth.REMOTE) + .onErrorReturn(false) + .concatWith(connection.close().then(Mono.empty()))) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @ParameterizedTest + @ValueSource(strings = { "replication", "" }) + void firstAvailable(String protocol) { + MySqlConnectionFactory.from(configuration(HaProtocol.from(protocol), false)).create() + .flatMapMany(connection -> connection.validate(ValidationDepth.REMOTE) + .onErrorReturn(false) + .concatWith(connection.close().then(Mono.empty()))) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + private MySqlConnectionConfiguration configuration(HaProtocol protocol, boolean badFirst) { + String password = System.getProperty("test.mysql.password"); + + assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty") + .isNotNull() + .isNotEmpty(); + + MySqlConnectionConfiguration.Builder builder = MySqlConnectionConfiguration.builder() + .protocol(protocol) + .connectTimeout(Duration.ofSeconds(3)) + .user("root") + .password(password) + .database("r2dbc"); + + if (badFirst) { + builder.addHost("127.0.0.1", 3310).addHost("127.0.0.1"); + } else { + builder.addHost("127.0.0.1").addHost("127.0.0.1", 3310); + } + + return builder.build(); + } +} diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatchTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatchTest.java index 2eaab1e9b..ef764b18d 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatchTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlBatchingBatchTest.java @@ -29,8 +29,7 @@ */ class MySqlBatchingBatchTest { - private final MySqlBatchingBatch batch = new MySqlBatchingBatch(mock(Client.class), mock(Codecs.class), - ConnectionContextTest.mock()); + private final MySqlBatchingBatch batch = new MySqlBatchingBatch(mock(Client.class), mock(Codecs.class)); @Test void add() { @@ -62,8 +61,7 @@ void badAdd() { @Test void addNothing() { - final MySqlBatchingBatch batch = new MySqlBatchingBatch(mock(Client.class), mock(Codecs.class), - ConnectionContextTest.mock()); + final MySqlBatchingBatch batch = new MySqlBatchingBatch(mock(Client.class), mock(Codecs.class)); assertEquals(batch.getSql(), ""); } } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java index 717ecaa0f..1c0a70003 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java @@ -17,23 +17,29 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.HaProtocol; +import io.asyncer.r2dbc.mysql.constant.ProtocolDriver; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.TlsVersions; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.asyncer.r2dbc.mysql.extension.Extension; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; import io.netty.handler.ssl.SslContextBuilder; -import org.assertj.core.api.ObjectAssert; +import io.r2dbc.spi.ConnectionFactoryOptions; import org.assertj.core.api.ThrowableTypeAssert; import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.time.Duration; -import java.time.ZoneId; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.Objects; +import java.util.Optional; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; @@ -67,24 +73,20 @@ void invalid() { @Test void unixSocket() { for (SslMode mode : SslMode.values()) { - if (mode.startSsl()) { - assertThatIllegalArgumentException().isThrownBy(() -> unixSocketSslMode(mode)) - .withMessageContaining("sslMode"); - } else { - assertThat(unixSocketSslMode(SslMode.DISABLED)).isNotNull(); - } + assertThat(unixSocketSslMode(mode)).isNotNull(); } MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration.builder() .unixSocket(UNIX_SOCKET) .user(USER) .build(); - ObjectAssert asserted = assertThat(configuration); - asserted.extracting(MySqlConnectionConfiguration::getDomain).isEqualTo(UNIX_SOCKET); - asserted.extracting(MySqlConnectionConfiguration::getUser).isEqualTo(USER); - asserted.extracting(MySqlConnectionConfiguration::isHost).isEqualTo(false); - asserted.extracting(MySqlConnectionConfiguration::getSsl) - .extracting(MySqlSslConfiguration::getSslMode).isEqualTo(SslMode.DISABLED); + + assertThat(((UnixDomainSocketConfiguration) configuration.getSocket()).getPath()).isEqualTo(UNIX_SOCKET); + assertThat(configuration.getSsl().getSslMode()).isEqualTo(SslMode.DISABLED); + configuration.getCredential() + .as(StepVerifier::create) + .expectNext(new Credential(USER, null)) + .verifyComplete(); } @Test @@ -93,12 +95,12 @@ void hosted() { .host(HOST) .user(USER) .build(); - ObjectAssert asserted = assertThat(configuration); - asserted.extracting(MySqlConnectionConfiguration::getDomain).isEqualTo(HOST); - asserted.extracting(MySqlConnectionConfiguration::getUser).isEqualTo(USER); - asserted.extracting(MySqlConnectionConfiguration::isHost).isEqualTo(true); - asserted.extracting(MySqlConnectionConfiguration::getSsl) - .extracting(MySqlSslConfiguration::getSslMode).isEqualTo(SslMode.PREFERRED); + assertThat(((TcpSocketConfiguration) configuration.getSocket()).getAddresses()) + .isEqualTo(Collections.singletonList(new NodeAddress(HOST))); + assertThat(configuration.getSsl().getSslMode()).isEqualTo(SslMode.PREFERRED); + configuration.getCredential().as(StepVerifier::create) + .expectNext(new Credential(USER, null)) + .verifyComplete(); } @Test @@ -106,18 +108,20 @@ void allSslModeHosted() { String sslCa = "/path/to/ca.pem"; for (SslMode mode : SslMode.values()) { - ObjectAssert asserted = assertThat(hostedSslMode(mode, sslCa)); + MySqlConnectionConfiguration configuration = hostedSslMode(mode, sslCa); - asserted.extracting(MySqlConnectionConfiguration::getDomain).isEqualTo(HOST); - asserted.extracting(MySqlConnectionConfiguration::getUser).isEqualTo(USER); - asserted.extracting(MySqlConnectionConfiguration::isHost).isEqualTo(true); - asserted.extracting(MySqlConnectionConfiguration::getSsl) - .extracting(MySqlSslConfiguration::getSslMode).isEqualTo(mode); + assertThat(configuration.getSsl().getSslMode()).isEqualTo(mode); + assertThat(((TcpSocketConfiguration) configuration.getSocket()).getAddresses()) + .isEqualTo(Collections.singletonList(new NodeAddress(HOST))); if (mode.startSsl()) { - asserted.extracting(MySqlConnectionConfiguration::getSsl) - .extracting(MySqlSslConfiguration::getSslCa).isSameAs(sslCa); + assertThat(configuration.getSsl().getSslCa()).isSameAs(sslCa); } + + configuration.getCredential() + .as(StepVerifier::create) + .expectNext(new Credential(USER, null)) + .verifyComplete(); } } @@ -131,13 +135,7 @@ void invalidPort() { @Test void allFillUp() { - assertThat(filledUp()).extracting(MySqlConnectionConfiguration::getSsl).isNotNull(); - } - - @Test - void isEquals() { - assertThat(filledUp()).isEqualTo(filledUp()).extracting(Objects::hashCode) - .isEqualTo(filledUp().hashCode()); + assertThat(filledUp().getSsl()).isNotNull(); } @Test @@ -194,19 +192,63 @@ void nonAutodetectExtensions() { @Test void validPasswordSupplier() { - final Mono passwordSupplier = Mono.just("123456"); + Mono passwordSupplier = Mono.just("123456"); + Mono.from(MySqlConnectionConfiguration.builder() .host(HOST) .user(USER) .passwordPublisher(passwordSupplier) .autodetectExtensions(false) .build() - .getPasswordPublisher()) + .getCredential()) .as(StepVerifier::create) - .expectNext("123456") + .expectNext(new Credential(USER, "123456")) .verifyComplete(); } + @ParameterizedTest + @ValueSource(strings = { + "r2dbc:mysql://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql+srv://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql+srv://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql:replication://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql:replication://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql+srv:replication://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql+srv:replication://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql:loadbalance://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql:loadbalance://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql+srv:loadbalance://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql+srv:loadbalance://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql:sequential://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql:sequential://my-db1:3309,my-db2:3310/r2dbc", + "r2dbc:mysql+srv:sequential://my-db1:3309,my-db2:3310/r2dbc", + "r2dbcs:mysql+srv:sequential://my-db1:3309,my-db2:3310/r2dbc", + }) + void multipleHosts(String url) { + ConnectionFactoryOptions options = ConnectionFactoryOptions.parse(url) + .mutate() + .option(ConnectionFactoryOptions.USER, "root") + .build(); + MySqlConnectionConfiguration configuration = MySqlConnectionFactoryProvider.setup(options); + + assertThat(configuration.getSocket()).isInstanceOf(TcpSocketConfiguration.class); + + TcpSocketConfiguration tcp = (TcpSocketConfiguration) configuration.getSocket(); + + assertThat(tcp.getAddresses()).isEqualTo(Arrays.asList( + new NodeAddress("my-db1", 3309), + new NodeAddress("my-db2", 3310) + )); + assertThat(tcp.getDriver()).isEqualTo( + ProtocolDriver.from(options.getRequiredValue(ConnectionFactoryOptions.DRIVER).toString())); + assertThat(tcp.getProtocol()).isEqualTo( + HaProtocol.from(Optional.ofNullable(options.getValue(ConnectionFactoryOptions.PROTOCOL)) + .map(Object::toString) + .orElse("")) + ); + } + private static MySqlConnectionConfiguration unixSocketSslMode(SslMode sslMode) { return MySqlConnectionConfiguration.builder() .unixSocket(UNIX_SOCKET) @@ -225,6 +267,7 @@ private static MySqlConnectionConfiguration hostedSslMode(SslMode sslMode, @Null } private static MySqlConnectionConfiguration filledUp() { + // Since 1.0.5, the passwordPublisher is Mono, equals() and hashCode() are not reliable. return MySqlConnectionConfiguration.builder() .host(HOST) .user(USER) diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java index 41b2ef45a..3e6e8cb3d 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java @@ -17,8 +17,10 @@ package io.asyncer.r2dbc.mysql; import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import io.asyncer.r2dbc.mysql.constant.HaProtocol; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; import io.netty.handler.ssl.SslContextBuilder; import io.r2dbc.spi.ConnectionFactories; import io.r2dbc.spi.ConnectionFactoryOptions; @@ -31,6 +33,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSession; @@ -39,7 +42,6 @@ import java.lang.reflect.Modifier; import java.net.URLEncoder; import java.time.Duration; -import java.time.ZoneId; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -57,6 +59,7 @@ import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; import static io.r2dbc.spi.ConnectionFactoryOptions.PORT; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; import static org.assertj.core.api.Assertions.assertThat; @@ -94,6 +97,25 @@ void validUrl() throws UnsupportedEncodingException { "sslKeyPassword=ssl123456")).isExactlyInstanceOf(MySqlConnectionFactory.class); } + @ParameterizedTest + @ValueSource(strings = { + "r2dbc:mysql://localhost:3306", + "r2dbcs:mysql://root@localhost:3306", + "r2dbc:mysql://root@localhost:3306?unixSocket=/path/to/mysql.sock", + "r2dbcs:mysql://mysql-region-1.some-cloud.com,mysql-region-2.some-cloud.com:3307", + "r2dbc:mysql:loadbalance://mysql-region-1.some-cloud.com,mysql-region-2.some-cloud.com:3307", + "r2dbc:mysql:sequential://mysql-region-1.some-cloud.com:3306,mysql-region-2.some-cloud.com:3307", + "r2dbcs:mysql:replication://mysql-region-1.some-cloud.com:3305,mysql-region-2.some-cloud.com:3307", + "r2dbc:mysql+srv:loadbalance://mysql-region-1.some-cloud.com,mysql-region-2.some-cloud.com:3307", + "r2dbc:mysql+srv:sequential://mysql-region-1.some-cloud.com:3306,mysql-region-2.some-cloud.com:3307", + "r2dbcs:mysql+srv:replication://mysql-region-1.some-cloud.com:3305,mysql-region-2.some-cloud.com:3307", + }) + void supports(String url) { + MySqlConnectionFactoryProvider provider = new MySqlConnectionFactoryProvider(); + + assertThat(provider.supports(ConnectionFactoryOptions.parse(url))).isTrue(); + } + @Test void urlSslModeInUnixSocket() throws UnsupportedEncodingException { Assert that = assertThat(SslMode.DISABLED); @@ -135,6 +157,7 @@ void validProgrammaticHost() { options = ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") + .option(PROTOCOL, "replication") .option(HOST, "127.0.0.1") .option(PORT, 3307) .option(USER, "root") @@ -161,16 +184,25 @@ void validProgrammaticHost() { MySqlConnectionConfiguration configuration = MySqlConnectionFactoryProvider.setup(options); - assertThat(configuration.getDomain()).isEqualTo("127.0.0.1"); - assertThat(configuration.isHost()).isTrue(); - assertThat(configuration.getPort()).isEqualTo(3307); - assertThat(configuration.getUser()).isEqualTo("root"); - assertThat(configuration.getPassword()).isEqualTo("123456"); - assertThat(configuration.getConnectTimeout()).isEqualTo(Duration.ofSeconds(3)); + assertThat(configuration.getSocket()).isInstanceOf(TcpSocketConfiguration.class); + + TcpSocketConfiguration tcp = (TcpSocketConfiguration) configuration.getSocket(); + + assertThat(tcp.getProtocol()) + .isEqualTo(HaProtocol.REPLICATION); + assertThat(tcp.getAddresses()) + .isEqualTo(Collections.singletonList(new NodeAddress("127.0.0.1", 3307))); + assertThat(tcp.isTcpKeepAlive()).isTrue(); + assertThat(tcp.isTcpNoDelay()).isTrue(); + + configuration.getCredential() + .as(StepVerifier::create) + .expectNext(new Credential("root", "123456")) + .verifyComplete(); + + assertThat(configuration.getClient().getConnectTimeout()).isEqualTo(Duration.ofSeconds(3)); assertThat(configuration.getDatabase()).isEqualTo("r2dbc"); assertThat(configuration.getZeroDateOption()).isEqualTo(ZeroDateOption.USE_ROUND); - assertThat(configuration.isTcpKeepAlive()).isTrue(); - assertThat(configuration.isTcpNoDelay()).isTrue(); assertThat(configuration.getConnectionTimeZone()).isEqualTo("Asia/Tokyo"); assertThat(configuration.getPreferPrepareStatement()).isExactlyInstanceOf(AllTruePredicate.class); assertThat(configuration.getExtensions()).isEqualTo(Extensions.from(Collections.emptyList(), true)); @@ -248,9 +280,7 @@ void invalidProgrammatic() { @Test void validProgrammaticUnixSocket() { - Assert domain = assertThat("/path/to/mysql.sock"); - Assert isHost = assertThat(false); - Assert sslMode = assertThat(SslMode.DISABLED); + Assert path = assertThat("/path/to/mysql.sock"); ConnectionFactoryOptions options = ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") @@ -260,9 +290,11 @@ void validProgrammaticUnixSocket() { .build(); MySqlConnectionConfiguration configuration = MySqlConnectionFactoryProvider.setup(options); - domain.isEqualTo(configuration.getDomain()); - isHost.isEqualTo(configuration.isHost()); - sslMode.isEqualTo(configuration.getSsl().getSslMode()); + assertThat(configuration.getSocket()).isInstanceOf(UnixDomainSocketConfiguration.class); + + UnixDomainSocketConfiguration unix = (UnixDomainSocketConfiguration) configuration.getSocket(); + + path.isEqualTo(unix.getPath()); for (SslMode mode : SslMode.values()) { configuration = MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() @@ -272,9 +304,11 @@ void validProgrammaticUnixSocket() { .option(Option.valueOf("sslMode"), mode.name().toLowerCase()) .build()); - domain.isEqualTo(configuration.getDomain()); - isHost.isEqualTo(configuration.isHost()); - sslMode.isEqualTo(configuration.getSsl().getSslMode()); + assertThat(configuration.getSocket()).isInstanceOf(UnixDomainSocketConfiguration.class); + + unix = (UnixDomainSocketConfiguration) configuration.getSocket(); + + path.isEqualTo(unix.getPath()); } configuration = MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() @@ -303,31 +337,24 @@ void validProgrammaticUnixSocket() { .option(Option.valueOf("tcpNoDelay"), "true") .build()); - assertThat(configuration.getDomain()).isEqualTo("/path/to/mysql.sock"); - assertThat(configuration.isHost()).isFalse(); - assertThat(configuration.getPort()).isEqualTo(3306); - assertThat(configuration.getUser()).isEqualTo("root"); - assertThat(configuration.getPassword()).isEqualTo("123456"); - assertThat(configuration.getConnectTimeout()).isEqualTo(Duration.ofSeconds(3)); + assertThat(configuration.getSocket()).isInstanceOf(UnixDomainSocketConfiguration.class); + + unix = (UnixDomainSocketConfiguration) configuration.getSocket(); + + assertThat(unix.getPath()).isEqualTo("/path/to/mysql.sock"); + + configuration.getCredential() + .as(StepVerifier::create) + .expectNext(new Credential("root", "123456")) + .verifyComplete(); + + assertThat(configuration.getClient().getConnectTimeout()).isEqualTo(Duration.ofSeconds(3)); assertThat(configuration.getDatabase()).isEqualTo("r2dbc"); assertThat(configuration.isCreateDatabaseIfNotExist()).isTrue(); assertThat(configuration.getZeroDateOption()).isEqualTo(ZeroDateOption.USE_ROUND); - assertThat(configuration.isTcpKeepAlive()).isTrue(); - assertThat(configuration.isTcpNoDelay()).isTrue(); assertThat(configuration.getConnectionTimeZone()).isEqualTo("Asia/Tokyo"); assertThat(configuration.getPreferPrepareStatement()).isExactlyInstanceOf(AllTruePredicate.class); assertThat(configuration.getExtensions()).isEqualTo(Extensions.from(Collections.emptyList(), true)); - - assertThat(configuration.getSsl().getSslMode()).isEqualTo(SslMode.DISABLED); - assertThat(configuration.getSsl().getTlsVersion()).isEmpty(); - assertThat(configuration.getSsl().getSslCa()).isNull(); - assertThat(configuration.getSsl().getSslKey()).isNull(); - assertThat(configuration.getSsl().getSslCert()).isNull(); - assertThat(configuration.getSsl().getSslKeyPassword()).isNull(); - assertThat(configuration.getSsl().getSslHostnameVerifier()).isNull(); - SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); - assertThat(sslContextBuilder) - .isSameAs(configuration.getSsl().customizeSslContext(sslContextBuilder)); } @Test @@ -455,14 +482,13 @@ void validPasswordSupplier() { @Test void allConfigurationOptions() { - List exceptConfigs = Arrays.asList( + List exceptConfigs = Arrays.asList( "extendWith", "username", + "addHost", "zeroDateOption"); List exceptOptions = Arrays.asList( - "driver", "ssl", - "protocol", "zeroDate", "lockWaitTimeout", "statementTimeout"); diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java index 9fa2395b8..6a87ce508 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSimpleConnectionTest.java @@ -39,7 +39,7 @@ */ class MySqlSimpleConnectionTest { - private final Client client = mock(Client.class); + private final Client client; private final Codecs codecs = mock(Codecs.class); @@ -47,21 +47,29 @@ class MySqlSimpleConnectionTest { private final String product = "MockConnection"; - private final MySqlSimpleConnection noPrepare = new MySqlSimpleConnection(client, ConnectionContextTest.mock(), - codecs, level, 50, Caches.createQueryCache(0), - Caches.createPrepareCache(0), product, null); + private final MySqlSimpleConnection noPrepare; + + MySqlSimpleConnectionTest() { + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock()); + + this.client = client; + this.noPrepare = new MySqlSimpleConnection(client, codecs, level, 50, + Caches.createQueryCache(0), Caches.createPrepareCache(0), product, null); + } @Test void createStatement() { String condition = "SELECT * FROM test"; - MySqlSimpleConnection allPrepare = new MySqlSimpleConnection(client, ConnectionContextTest.mock(), - codecs, level, 50, Caches.createQueryCache(0), + MySqlSimpleConnection allPrepare = new MySqlSimpleConnection( + client, codecs, level, 50, Caches.createQueryCache(0), Caches.createPrepareCache(0), product, sql -> true); - MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection(client, ConnectionContextTest.mock(), - codecs, level, 50, Caches.createQueryCache(0), + MySqlSimpleConnection halfPrepare = new MySqlSimpleConnection( + client, codecs, level, 50, Caches.createQueryCache(0), Caches.createPrepareCache(0), product, sql -> false); - MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection(client, ConnectionContextTest.mock(), - codecs, level, 50, Caches.createQueryCache(0), + MySqlSimpleConnection conditionPrepare = new MySqlSimpleConnection( + client, codecs, level, 50, Caches.createQueryCache(0), Caches.createPrepareCache(0), product, sql -> sql.equals(condition)); assertThat(noPrepare.createStatement("SELECT * FROM test WHERE id=1")) diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatchTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatchTest.java index c381c8418..ebcb4589e 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatchTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlSyntheticBatchTest.java @@ -28,8 +28,7 @@ */ class MySqlSyntheticBatchTest { - private final MySqlSyntheticBatch batch = new MySqlSyntheticBatch(mock(Client.class), mock(Codecs.class), - ConnectionContextTest.mock()); + private final MySqlSyntheticBatch batch = new MySqlSyntheticBatch(mock(Client.class), mock(Codecs.class)); @SuppressWarnings("ConstantConditions") @Test diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlTestKitSupport.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlTestKitSupport.java index 7b85d4150..832e56d64 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlTestKitSupport.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/MySqlTestKitSupport.java @@ -17,6 +17,7 @@ package io.asyncer.r2dbc.mysql; import com.zaxxer.hikari.HikariDataSource; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; import io.r2dbc.spi.test.TestKit; import org.springframework.jdbc.core.JdbcTemplate; @@ -75,15 +76,19 @@ public String clobType() { } private static JdbcTemplate jdbc(MySqlConnectionConfiguration configuration) { + TcpSocketConfiguration socket = (TcpSocketConfiguration) configuration.getSocket(); + NodeAddress address = socket.getFirstAddress(); + Credential credential = configuration.getCredential().blockOptional().orElseThrow(() -> + new IllegalStateException("Credential must be present")); HikariDataSource source = new HikariDataSource(); - source.setJdbcUrl(String.format("jdbc:mysql://%s:%d/%s", configuration.getDomain(), - configuration.getPort(), configuration.getDatabase())); - source.setUsername(configuration.getUser()); - source.setPassword(Optional.ofNullable(configuration.getPassword()) + source.setJdbcUrl(String.format("jdbc:mysql://%s:%d/%s", + address.getHost(), address.getPort(), configuration.getDatabase())); + source.setUsername(credential.getUser()); + source.setPassword(Optional.ofNullable(credential.getPassword()) .map(Object::toString).orElse(null)); source.setMaximumPoolSize(1); - source.setConnectionTimeout(Optional.ofNullable(configuration.getConnectTimeout()) + source.setConnectionTimeout(Optional.ofNullable(configuration.getClient().getConnectTimeout()) .map(Duration::toMillis).orElse(0L)); source.addDataSourceProperty("preserveInstants", configuration.isPreserveInstants()); diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/OptionMapperTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/OptionMapperTest.java index 0952c95a5..e4786bb6a 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/OptionMapperTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/OptionMapperTest.java @@ -92,12 +92,15 @@ void otherwiseNoop() { AtomicReference ref = new AtomicReference<>(fill); AtomicReference other = new AtomicReference<>(fill); - new OptionMapper(ConnectionFactoryOptions.builder() + boolean set = new OptionMapper(ConnectionFactoryOptions.builder() .option(USER, "no-root") .build()) .requires(USER) - .to(ref::set) - .otherwise(() -> other.set(8)); + .to(ref::set); + + if (!set) { + other.set(8); + } assertThat(ref.get()).isEqualTo("no-root"); assertThat(other.get()).isSameAs(fill); @@ -109,11 +112,14 @@ void otherwiseFall() { AtomicReference ref = new AtomicReference<>(fill); AtomicReference other = new AtomicReference<>(fill); - new OptionMapper(ConnectionFactoryOptions.builder() + boolean set = new OptionMapper(ConnectionFactoryOptions.builder() .build()) .optional(USER) - .to(ref::set) - .otherwise(() -> other.set(8)); + .to(ref::set); + + if (!set) { + other.set(8); + } assertThat(ref.get()).isSameAs(fill); assertThat(other.get()).isEqualTo(8); diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java index c44bc5b28..f49d476e0 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java @@ -23,14 +23,13 @@ import java.lang.reflect.Field; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link PrepareParametrizedStatement}. */ class PrepareParametrizedStatementTest implements StatementTestSupport { - private final Client client = mock(Client.class); - private final Codecs codecs = Codecs.builder().build(); private final Field fetchSize = PrepareParametrizedStatement.class.getDeclaredField("fetchSize"); @@ -46,11 +45,14 @@ public int getFetchSize(PrepareParametrizedStatement statement) throws IllegalAc @Override public PrepareParametrizedStatement makeInstance(boolean isMariaDB, String sql, String ignored) { + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB)); + return new PrepareParametrizedStatement( client, codecs, Query.parse(sql), - ConnectionContextTest.mock(isMariaDB), Caches.createPrepareCache(0) ); } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java index 947af752d..0e18e7233 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatementTest.java @@ -23,14 +23,13 @@ import java.lang.reflect.Field; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link PrepareSimpleStatement}. */ class PrepareSimpleStatementTest implements StatementTestSupport { - private final Client client = mock(Client.class); - private final Codecs codecs = mock(Codecs.class); private final Field fetchSize = PrepareSimpleStatement.class.getDeclaredField("fetchSize"); @@ -61,10 +60,13 @@ public int getFetchSize(PrepareSimpleStatement statement) throws IllegalAccessEx @Override public PrepareSimpleStatement makeInstance(boolean isMariaDB, String ignored, String sql) { + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB)); + return new PrepareSimpleStatement( client, codecs, - ConnectionContextTest.mock(isMariaDB), sql, Caches.createPrepareCache(0) ); diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ProtocolDriverIntegrationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ProtocolDriverIntegrationTest.java new file mode 100644 index 000000000..940d89335 --- /dev/null +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/ProtocolDriverIntegrationTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.r2dbc.spi.ConnectionFactoryOptions; +import io.r2dbc.spi.ValidationDepth; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for DNS SRV records. + */ +class ProtocolDriverIntegrationTest { + + @ParameterizedTest + @ValueSource(strings = { + "r2dbc:mysql+srv:loadbalance://localhost:3306/r2dbc", + }) + void anyAvailable(String url) { + // localhost should be resolved to 127.0.0.1 and [::1], but I can't make sure GitHub Actions support IPv6 + MySqlConnectionFactory.from(MySqlConnectionFactoryProvider.setup(setupUrlAndCredentials(url))) + .create() + .flatMapMany(connection -> connection.validate(ValidationDepth.REMOTE) + .onErrorReturn(false) + .concatWith(connection.close().then(Mono.empty()))) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + private static ConnectionFactoryOptions setupUrlAndCredentials(String url) { + String password = System.getProperty("test.mysql.password"); + + assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty") + .isNotNull() + .isNotEmpty(); + + return ConnectionFactoryOptions.parse(url).mutate() + .option(ConnectionFactoryOptions.USER, "root") + .option(ConnectionFactoryOptions.PASSWORD, password) + .option(ConnectionFactoryOptions.CONNECT_TIMEOUT, Duration.ofSeconds(3)) + .build(); + } +} diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java index 4e833369b..1175fe5db 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java @@ -20,14 +20,13 @@ import io.asyncer.r2dbc.mysql.codec.Codecs; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link TextParametrizedStatement}. */ class TextParametrizedStatementTest implements StatementTestSupport { - private final Client client = mock(Client.class); - private final Codecs codecs = Codecs.builder().build(); @Override @@ -37,11 +36,14 @@ public void fetchSize() { @Override public TextParametrizedStatement makeInstance(boolean isMariaDB, String sql, String ignored) { + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB)); + return new TextParametrizedStatement( client, codecs, - Query.parse(sql), - ConnectionContextTest.mock(isMariaDB) + Query.parse(sql) ); } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextSimpleStatementTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextSimpleStatementTest.java index 43cb1025c..5c74543b0 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextSimpleStatementTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TextSimpleStatementTest.java @@ -20,14 +20,13 @@ import io.asyncer.r2dbc.mysql.codec.Codecs; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Unit tests for {@link TextSimpleStatement}. */ class TextSimpleStatementTest implements StatementTestSupport { - private final Client client = mock(Client.class); - private final Codecs codecs = mock(Codecs.class); @Override @@ -52,7 +51,11 @@ public void fetchSize() { @Override public TextSimpleStatement makeInstance(boolean isMariaDB, String ignored, String sql) { - return new TextSimpleStatement(client, codecs, ConnectionContextTest.mock(isMariaDB), sql); + Client client = mock(Client.class); + + when(client.getContext()).thenReturn(ConnectionContextTest.mock(isMariaDB)); + + return new TextSimpleStatement(client, codecs, sql); } @Override diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTest.java index 99da15e3c..3ad60648c 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTest.java @@ -2,6 +2,7 @@ import com.zaxxer.hikari.HikariDataSource; import io.asyncer.r2dbc.mysql.api.MySqlResult; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; import org.assertj.core.data.TemporalUnitOffset; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; @@ -307,22 +308,26 @@ private static MySqlConnectionConfiguration configuration( return customizer.apply(builder).build(); } - private static JdbcTemplate jdbc(MySqlConnectionConfiguration config) { + private static JdbcTemplate jdbc(MySqlConnectionConfiguration configuration) { + TcpSocketConfiguration socket = (TcpSocketConfiguration) configuration.getSocket(); + NodeAddress address = socket.getFirstAddress(); + Credential credential = configuration.getCredential().blockOptional().orElseThrow(() -> + new IllegalStateException("Credential must be present")); HikariDataSource source = new HikariDataSource(); - source.setJdbcUrl(String.format("jdbc:mysql://%s:%d/%s", config.getDomain(), - config.getPort(), config.getDatabase())); - source.setUsername(config.getUser()); - source.setPassword(Optional.ofNullable(config.getPassword()) + source.setJdbcUrl(String.format("jdbc:mysql://%s:%d/%s", + address.getHost(), address.getPort(), configuration.getDatabase())); + source.setUsername(credential.getUser()); + source.setPassword(Optional.ofNullable(credential.getPassword()) .map(Object::toString).orElse(null)); source.setMaximumPoolSize(1); - source.setConnectionTimeout(Optional.ofNullable(config.getConnectTimeout()) + source.setConnectionTimeout(Optional.ofNullable(configuration.getClient().getConnectTimeout()) .map(Duration::toMillis).orElse(0L)); - source.addDataSourceProperty("preserveInstants", config.isPreserveInstants()); - source.addDataSourceProperty("connectionTimeZone", config.getConnectionTimeZone()); + source.addDataSourceProperty("preserveInstants", configuration.isPreserveInstants()); + source.addDataSourceProperty("connectionTimeZone", configuration.getConnectionTimeZone()); source.addDataSourceProperty("forceConnectionTimeZoneToSession", - config.isForceConnectionTimeZoneToSession()); + configuration.isForceConnectionTimeZoneToSession()); return new JdbcTemplate(source); } diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtilsTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtilsTest.java index da22bcc14..e9b1d4f64 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtilsTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/internal/util/AddressUtilsTest.java @@ -16,76 +16,180 @@ package io.asyncer.r2dbc.mysql.internal.util; -import org.junit.jupiter.api.Test; +import io.asyncer.r2dbc.mysql.internal.NodeAddress; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link AddressUtils}. */ class AddressUtilsTest { - @Test - void isIpv4() { - assertTrue(AddressUtils.isIpv4("1.0.0.0")); - assertTrue(AddressUtils.isIpv4("127.0.0.1")); - assertTrue(AddressUtils.isIpv4("10.11.12.13")); - assertTrue(AddressUtils.isIpv4("192.168.0.0")); - assertTrue(AddressUtils.isIpv4("255.255.255.255")); + @ParameterizedTest + @ValueSource(strings = { + "1.0.0.0", + "127.0.0.1", + "10.11.12.13", + "192.168.0.0", + "255.255.255.255", + }) + void isIpv4(String address) { + assertThat(AddressUtils.isIpv4(address)).isTrue(); + } - assertFalse(AddressUtils.isIpv4("0.0.0.0")); - assertFalse(AddressUtils.isIpv4(" 127.0.0.1 ")); - assertFalse(AddressUtils.isIpv4("01.11.12.13")); - assertFalse(AddressUtils.isIpv4("092.168.0.1")); - assertFalse(AddressUtils.isIpv4("055.255.255.255")); - assertFalse(AddressUtils.isIpv4("g.ar.ba.ge")); - assertFalse(AddressUtils.isIpv4("192.168.0")); - assertFalse(AddressUtils.isIpv4("192.168.0a.0")); - assertFalse(AddressUtils.isIpv4("256.255.255.255")); - assertFalse(AddressUtils.isIpv4("0.255.255.255")); + @ParameterizedTest + @ValueSource(strings = { + "0.0.0.0", + " 127.0.0.1 ", + "01.11.12.13", + "092.168.0.1", + "055.255.255.255", + "g.ar.ba.ge", + "192.168.0", + "192.168.0a.0", + "256.255.255.255", + "0.255.255.255", + "::", + "::1", + "0:0:0:0:0:0:0:0", + "0:0:0:0:0:0:0:1", + "2001:0acd:0000:0000:0000:0000:3939:21fe", + "2001:acd:0:0:0:0:3939:21fe", + "2001:0acd:0:0::3939:21fe", + "2001:0acd::3939:21fe", + "2001:acd::3939:21fe", + }) + void isNotIpv4(String address) { + assertThat(AddressUtils.isIpv4(address)).isFalse(); + } - assertFalse(AddressUtils.isIpv4("::")); - assertFalse(AddressUtils.isIpv4("::1")); - assertFalse(AddressUtils.isIpv4("0:0:0:0:0:0:0:0")); - assertFalse(AddressUtils.isIpv4("0:0:0:0:0:0:0:1")); - assertFalse(AddressUtils.isIpv4("2001:0acd:0000:0000:0000:0000:3939:21fe")); - assertFalse(AddressUtils.isIpv4("2001:acd:0:0:0:0:3939:21fe")); - assertFalse(AddressUtils.isIpv4("2001:0acd:0:0::3939:21fe")); - assertFalse(AddressUtils.isIpv4("2001:0acd::3939:21fe")); - assertFalse(AddressUtils.isIpv4("2001:acd::3939:21fe")); + @ParameterizedTest + @ValueSource(strings = { + "::", + "::1", + "0:0:0:0:0:0:0:0", + "0:0:0:0:0:0:0:1", + "2001:0acd:0000:0000:0000:0000:3939:21fe", + "2001:acd:0:0:0:0:3939:21fe", + "2001:0acd:0:0::3939:21fe", + "2001:0acd::3939:21fe", + "2001:acd::3939:21fe", + }) + void isIpv6(String address) { + assertThat(AddressUtils.isIpv6(address)).isTrue(); } - @Test - void isIpv6() { - assertTrue(AddressUtils.isIpv6("::")); - assertTrue(AddressUtils.isIpv6("::1")); - assertTrue(AddressUtils.isIpv6("0:0:0:0:0:0:0:0")); - assertTrue(AddressUtils.isIpv6("0:0:0:0:0:0:0:1")); - assertTrue(AddressUtils.isIpv6("2001:0acd:0000:0000:0000:0000:3939:21fe")); - assertTrue(AddressUtils.isIpv6("2001:acd:0:0:0:0:3939:21fe")); - assertTrue(AddressUtils.isIpv6("2001:0acd:0:0::3939:21fe")); - assertTrue(AddressUtils.isIpv6("2001:0acd::3939:21fe")); - assertTrue(AddressUtils.isIpv6("2001:acd::3939:21fe")); + @ParameterizedTest + @ValueSource(strings = { + "", + ":1", + "0:0:0:0:0:0:0", + "0:0:0:0:0:0:0:0:0", + "2001:0acd:0000:garb:age0:0000:3939:21fe", + "2001:0agd:0000:0000:0000:0000:3939:21fe", + "2001:0acd::0000::21fe", + "1:2:3:4:5:6:7::9", + "1::3:4:5:6:7:8:9", + "::3:4:5:6:7:8:9", + "1:2::4:5:6:7:8:9", + "1:2:3:4:5:6::8:9", + "0.0.0.0", + "1.0.0.0", + "127.0.0.1", + "10.11.12.13", + "192.168.0.0", + "255.255.255.255", + }) + void isNotIpv6(String address) { + assertThat(AddressUtils.isIpv6(address)).isFalse(); + } - assertFalse(AddressUtils.isIpv6("")); - assertFalse(AddressUtils.isIpv6(":1")); - assertFalse(AddressUtils.isIpv6("0:0:0:0:0:0:0")); - assertFalse(AddressUtils.isIpv6("0:0:0:0:0:0:0:0:0")); - assertFalse(AddressUtils.isIpv6("2001:0acd:0000:garb:age0:0000:3939:21fe")); - assertFalse(AddressUtils.isIpv6("2001:0agd:0000:0000:0000:0000:3939:21fe")); - assertFalse(AddressUtils.isIpv6("2001:0acd::0000::21fe")); - assertFalse(AddressUtils.isIpv6("1:2:3:4:5:6:7::9")); - assertFalse(AddressUtils.isIpv6("1::3:4:5:6:7:8:9")); - assertFalse(AddressUtils.isIpv6("::3:4:5:6:7:8:9")); - assertFalse(AddressUtils.isIpv6("1:2::4:5:6:7:8:9")); - assertFalse(AddressUtils.isIpv6("1:2:3:4:5:6::8:9")); + @ParameterizedTest + @MethodSource + void parseAddress(String host, NodeAddress except) { + assertThat(AddressUtils.parseAddress(host)).isEqualTo(except); + } - assertFalse(AddressUtils.isIpv6("0.0.0.0")); - assertFalse(AddressUtils.isIpv6("1.0.0.0")); - assertFalse(AddressUtils.isIpv6("127.0.0.1")); - assertFalse(AddressUtils.isIpv6("10.11.12.13")); - assertFalse(AddressUtils.isIpv6("192.168.0.0")); - assertFalse(AddressUtils.isIpv6("255.255.255.255")); + static Stream parseAddress() { + return Stream.of( + Arguments.of("localhost", new NodeAddress("localhost")), + Arguments.of("localhost:", new NodeAddress("localhost:")), + Arguments.of("localhost:1", new NodeAddress("localhost", 1)), + Arguments.of("localhost:3307", new NodeAddress("localhost", 3307)), + Arguments.of("localhost:65535", new NodeAddress("localhost", 65535)), + Arguments.of("localhost:65536", new NodeAddress("localhost:65536")), + Arguments.of("localhost:165536", new NodeAddress("localhost:165536")), + Arguments.of("a1234", new NodeAddress("a1234")), + Arguments.of(":1234", new NodeAddress(":1234")), + Arguments.of("[]:3305", new NodeAddress("[]:3305")), + Arguments.of("[::1]", new NodeAddress("[::1]")), + Arguments.of("[::1]:2", new NodeAddress("[::1]", 2)), + Arguments.of("[::1]:567", new NodeAddress("[::1]", 567)), + Arguments.of("[::1]:65535", new NodeAddress("[::1]", 65535)), + Arguments.of("[::1]:65536", new NodeAddress("[::1]:65536")), + Arguments.of("[::]", new NodeAddress("[::]")), + Arguments.of("[1::]", new NodeAddress("[1::]")), + Arguments.of("[::]:3", new NodeAddress("[::]", 3)), + Arguments.of("[::]:65536", new NodeAddress("[::]:65536")), + Arguments.of( + "[2001::2:3307]", + new NodeAddress("[2001::2:3307]") + ), + Arguments.of( + "[2001::2]:3307", + new NodeAddress("[2001::2]", 3307) + ), + Arguments.of( + "[a772:8380:7adf:77fd:4d58:d629:a237:0b5e]", + new NodeAddress("[a772:8380:7adf:77fd:4d58:d629:a237:0b5e]") + ), + Arguments.of( + "[ff19:7c3d:8ddb:c86c:647b:17d6:b64a:7930]:4", + new NodeAddress("[ff19:7c3d:8ddb:c86c:647b:17d6:b64a:7930]", 4) + ), + Arguments.of( + "[1234:fd2:5621:1:89::45]:567", + new NodeAddress("[1234:fd2:5621:1:89::45]", 567) + ), + Arguments.of( + "[2001:470:26:12b:9a65:b818:6c96:4271]:65535", + new NodeAddress("[2001:470:26:12b:9a65:b818:6c96:4271]", 65535) + ), + Arguments.of("168.10.0.9", new NodeAddress("168.10.0.9")), + Arguments.of("168.10.0.9:5", new NodeAddress("168.10.0.9", 5)), + Arguments.of("168.10.0.9:1234", new NodeAddress("168.10.0.9", 1234)), + Arguments.of("168.10.0.9:65535", new NodeAddress("168.10.0.9", 65535)), + // See also https://github.com/asyncer-io/r2dbc-mysql/issues/255 + Arguments.of("my_db", new NodeAddress("my_db")), + Arguments.of("my_db:6", new NodeAddress("my_db", 6)), + Arguments.of("my_db:3307", new NodeAddress("my_db", 3307)), + Arguments.of("my_db:65535", new NodeAddress("my_db", 65535)), + Arguments.of("db-service", new NodeAddress("db-service")), + Arguments.of("db-service:7", new NodeAddress("db-service", 7)), + Arguments.of("db-service:3307", new NodeAddress("db-service", 3307)), + Arguments.of("db-service:65535", new NodeAddress("db-service", 65535)), + Arguments.of( + "region_asia.rds3.some-cloud.com", + new NodeAddress("region_asia.rds3.some-cloud.com") + ), + Arguments.of( + "region_asia.rds4.some-cloud.com:8", + new NodeAddress("region_asia.rds4.some-cloud.com", 8) + ), + Arguments.of( + "region_asia.rds5.some-cloud.com:425", + new NodeAddress("region_asia.rds5.some-cloud.com", 425) + ), + Arguments.of( + "region_asia.rds6.some-cloud.com:65535", + new NodeAddress("region_asia.rds6.some-cloud.com", 65535) + ) + ); } }