diff --git a/pom.xml b/pom.xml index 76bf1542d..d24ccb032 100644 --- a/pom.xml +++ b/pom.xml @@ -179,6 +179,11 @@ junit-jupiter-engine test + + org.junit.jupiter + junit-jupiter-params + test + org.mockito mockito-core diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java index ed8d04b42..b24f7edb3 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java @@ -21,6 +21,7 @@ import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import io.asyncer.r2dbc.mysql.message.client.InitDbMessage; import io.asyncer.r2dbc.mysql.message.client.PingMessage; import io.asyncer.r2dbc.mysql.message.server.CompleteMessage; @@ -47,8 +48,8 @@ import java.util.function.Function; import java.util.function.Predicate; +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; -import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireValidName; /** * An implementation of {@link Connection} for connecting to the MySQL database. @@ -222,7 +223,8 @@ public MySqlBatch createBatch() { @Override public Mono createSavepoint(String name) { - requireValidName(name, "Savepoint name must not be empty and not contain backticks"); + requireNonEmpty(name, "Savepoint name must not be empty"); + return QueryFlow.createSavepoint(client, this, name, batchSupported); } @@ -266,23 +268,21 @@ public Mono preRelease() { @Override public Mono releaseSavepoint(String name) { - requireValidName(name, "Savepoint name must not be empty and not contain backticks"); + requireNonEmpty(name, "Savepoint name must not be empty"); - return QueryFlow.executeVoid(client, String.format("RELEASE SAVEPOINT `%s`", name)); + return QueryFlow.executeVoid(client, "RELEASE SAVEPOINT " + StringUtils.quoteIdentifier(name)); } @Override public Mono rollbackTransaction() { - return Mono.defer(() -> { - return QueryFlow.doneTransaction(client, this, false, batchSupported); - }); + return Mono.defer(() -> QueryFlow.doneTransaction(client, this, false, batchSupported)); } @Override public Mono rollbackTransactionToSavepoint(String name) { - requireValidName(name, "Savepoint name must not be empty and not contain backticks"); + requireNonEmpty(name, "Savepoint name must not be empty"); - return QueryFlow.executeVoid(client, String.format("ROLLBACK TO SAVEPOINT `%s`", name)); + return QueryFlow.executeVoid(client, "ROLLBACK TO SAVEPOINT " + StringUtils.quoteIdentifier(name)); } @Override @@ -294,7 +294,7 @@ public MySqlConnectionMetadata getMetadata() { * MySQL does not have any way to query the isolation level of the current transaction, only inferred from * past statements, so driver can not make sure the result is right. *

- * See https://bugs.mysql.com/bug.php?id=53341 + * See MySQL Bug 53341 *

* {@inheritDoc} */ @@ -467,7 +467,7 @@ static Mono init( return connection; } - requireValidName(database, "database must not be empty and not contain backticks"); + requireNonEmpty(database, "database must not be empty"); return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB) .last() @@ -476,7 +476,7 @@ static Mono init( return Mono.just(conn); } - String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database); + String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database); return QueryFlow.executeVoid(client, sql) .then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn))); diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java index 4f67a085d..890914bde 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java @@ -16,12 +16,11 @@ package io.asyncer.r2dbc.mysql; - import org.jetbrains.annotations.Nullable; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; -import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireValidName; /** * Base class considers generic logic for {@link MySqlStatement} implementations. @@ -42,8 +41,8 @@ public final MySqlStatement returnGeneratedValues(String... columns) { this.generatedKeyName = LAST_INSERT_ID; return this; case 1: - this.generatedKeyName = requireValidName(columns[0], - "id name must not be empty and not contain backticks"); + requireNonEmpty(columns[0], "id name must not be empty"); + this.generatedKeyName = columns[0]; return this; } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java index f6d56cf6c..203c8ed5c 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java @@ -23,6 +23,7 @@ import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import io.asyncer.r2dbc.mysql.message.client.AuthResponse; import io.asyncer.r2dbc.mysql.message.client.ClientMessage; import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse; @@ -1230,7 +1231,7 @@ boolean cancelTasks() { statements.add("BEGIN"); } - final String doneSql = String.format("SAVEPOINT `%s`", name); + final String doneSql = "SAVEPOINT " + StringUtils.quoteIdentifier(name); tasks |= CREATE_SAVEPOINT; statements.add(doneSql); return false; diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java index 9aceb8570..246deac37 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java @@ -16,7 +16,6 @@ package io.asyncer.r2dbc.mysql.internal.util; - import org.jetbrains.annotations.Nullable; /** @@ -57,20 +56,17 @@ public static void require(boolean condition, String message) { } /** - * Checks that a specified {@link String} is not {@code null} or empty or backticks included and throws a - * customized {@link IllegalArgumentException} if it is. + * Checks that a {@link String} is neither {@code null} nor empty, and throws a customized + * {@link IllegalArgumentException} if it is. * - * @param name the {@link String} to check for nullity or empty or backticks included. + * @param s the string to check for empty. * @param message the detail message to be used by thrown {@link IllegalArgumentException}. - * @return {@code name} if not {@code null} or empty or backticks included. - * @throws IllegalArgumentException if {@code name} is {@code null} or empty or backticks included. + * @throws IllegalArgumentException if {@code s} is {@code null} or empty. */ - public static String requireValidName(@Nullable String name, String message) { - if (name == null || name.isEmpty() || name.indexOf('`') >= 0) { + public static void requireNonEmpty(@Nullable String s, String message) { + if (s == null || s.isEmpty()) { throw new IllegalArgumentException(message); } - - return name; } private AssertUtils() { } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java new file mode 100644 index 000000000..2ccbb55a5 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.internal.util; + +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; + +/** + * A utility for processing {@link String} in MySQL/MariaDB. + */ +public final class StringUtils { + + private static final char QUOTE = '`'; + + public static String quoteIdentifier(String identifier) { + requireNonEmpty(identifier, "identifier must not be empty"); + + int index = identifier.indexOf(QUOTE); + + if (index == -1) { + return QUOTE + identifier + QUOTE; + } + + int len = identifier.length(); + StringBuilder builder = new StringBuilder(len + 10).append(QUOTE); + int fromIndex = 0; + + while (index != -1) { + builder.append(identifier, fromIndex, index) + .append(QUOTE) + .append(QUOTE); + fromIndex = index + 1; + index = identifier.indexOf(QUOTE, fromIndex); + } + + if (fromIndex < len) { + builder.append(identifier, fromIndex, len); + } + + return builder.append(QUOTE).toString(); + } + + private StringUtils() { + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java index 5063c29f3..075604031 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java @@ -17,6 +17,8 @@ package io.asyncer.r2dbc.mysql; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -189,48 +191,50 @@ void autoCommitStatusIsRestoredAfterTransaction() { .doOnSuccess(ignored -> assertThat(connection.isAutoCommit()).isTrue())); } - @Test - void createSavepointAndRollbackToSavepoint() { + @ParameterizedTest + @ValueSource(strings = { "test", "save`point" }) + void createSavepointAndRollbackToSavepoint(String savepoint) { complete(connection -> Mono.from(connection.createStatement( - "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute()) - .flatMap(IntegrationTestSupport::extractRowsUpdated) - .then(connection.beginTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) - .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')") - .execute())) - .flatMap(IntegrationTestSupport::extractRowsUpdated) - .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')") - .execute())) - .flatMap(IntegrationTestSupport::extractRowsUpdated) - .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) - .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) - .doOnSuccess(count -> assertThat(count).isEqualTo(2)) - .then(connection.createSavepoint("test")) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) - .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')") - .execute())) - .flatMap(IntegrationTestSupport::extractRowsUpdated) - .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')") - .execute())) - .flatMap(IntegrationTestSupport::extractRowsUpdated) - .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) - .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) - .doOnSuccess(count -> assertThat(count).isEqualTo(4)) - .then(connection.rollbackTransactionToSavepoint("test")) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) - .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) - .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) - .doOnSuccess(count -> assertThat(count).isEqualTo(2)) - .then(connection.rollbackTransaction()) - .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) - .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) - .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) - .doOnSuccess(count -> assertThat(count).isEqualTo(0)) + "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute()) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .then(connection.beginTransaction()) + .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')") + .execute())) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')") + .execute())) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) + .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) + .doOnSuccess(count -> assertThat(count).isEqualTo(2)) + .then(connection.createSavepoint(savepoint)) + .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')") + .execute())) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')") + .execute())) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) + .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) + .doOnSuccess(count -> assertThat(count).isEqualTo(4)) + .then(connection.rollbackTransactionToSavepoint(savepoint)) + .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) + .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) + .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) + .doOnSuccess(count -> assertThat(count).isEqualTo(2)) + .then(connection.rollbackTransaction()) + .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse()) + .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) + .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) + .doOnSuccess(count -> assertThat(count).isEqualTo(0)) ); } - @Test - void createSavepointAndRollbackEntireTransaction() { + @ParameterizedTest + @ValueSource(strings = { "test", "save`point" }) + void createSavepointAndRollbackEntireTransaction(String savepoint) { complete(connection -> Mono.from(connection.createStatement( "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) @@ -245,7 +249,7 @@ void createSavepointAndRollbackEntireTransaction() { .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute())) .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class)))) .doOnSuccess(count -> assertThat(count).isEqualTo(2)) - .then(connection.createSavepoint("test")) + .then(connection.createSavepoint(savepoint)) .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue()) .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')") .execute())) diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java index ef39693b6..b82f1154d 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java @@ -98,9 +98,6 @@ void badCreateSavepoint() { ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.createSavepoint("")); - asserted.isThrownBy(() -> noPrepare.createSavepoint("`")); - asserted.isThrownBy(() -> noPrepare.createSavepoint("name`")); - asserted.isThrownBy(() -> noPrepare.createSavepoint("nam`e")); asserted.isThrownBy(() -> noPrepare.createSavepoint(null)); } @@ -110,9 +107,6 @@ void badReleaseSavepoint() { ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.releaseSavepoint("")); - asserted.isThrownBy(() -> noPrepare.releaseSavepoint("`")); - asserted.isThrownBy(() -> noPrepare.releaseSavepoint("name`")); - asserted.isThrownBy(() -> noPrepare.releaseSavepoint("nam`e")); asserted.isThrownBy(() -> noPrepare.releaseSavepoint(null)); } @@ -122,9 +116,6 @@ void badRollbackTransactionToSavepoint() { ThrowableTypeAssert asserted = assertThatIllegalArgumentException(); asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("")); - asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("`")); - asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("name`")); - asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("nam`e")); asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(null)); } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java index 0de394fba..65300af67 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java @@ -156,6 +156,8 @@ default void returnGeneratedValues() { assertEquals(statement.generatedKeyName, "LAST_INSERT_ID"); statement.returnGeneratedValues("generated"); assertEquals(statement.generatedKeyName, "generated"); + statement.returnGeneratedValues("generate`d"); + assertEquals(statement.generatedKeyName, "generate`d"); } @SuppressWarnings("ConstantConditions") @@ -166,8 +168,6 @@ default void badReturnGeneratedValues() { assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String) null)); assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String[]) null)); assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("")); - assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("`generating`")); - assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("generating`")); assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("generated", "names")); } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java new file mode 100644 index 000000000..e8be1f442 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.internal.util; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Unit tests for {@link StringUtils}. + */ +class StringUtilsTest { + + @ParameterizedTest + @ValueSource(strings = { + " ", + "`", + "``", + "bad`", + "`name", + "bad`name", + "`bad`name`", + "μ's", + " Reading books can be a \"great way\" to help you achieve more 'success' in `life` ", + "b%!@#()ar`\nfr-=321d`na``me", + "`b%!@#()ar`\nfr-=321d`na``me", + "`b%!@#()ar`\nfr-=321d`na``me``", + }) + void quoteIdentifier(String name) { + assertThat(StringUtils.quoteIdentifier(name)).isEqualTo('`' + name.replaceAll("`", "``") + '`'); + } + + @ParameterizedTest + @NullAndEmptySource + void badQuoteIdentifier(String name) { + assertThatIllegalArgumentException().isThrownBy(() -> StringUtils.quoteIdentifier(name)); + } +}