rollbackTransactionToSavepoint(String name) {
- requireValidName(name, "Savepoint name must not be empty and not contain backticks");
+ requireNonEmpty(name, "Savepoint name must not be empty");
- return QueryFlow.executeVoid(client, String.format("ROLLBACK TO SAVEPOINT `%s`", name));
+ return QueryFlow.executeVoid(client, "ROLLBACK TO SAVEPOINT " + StringUtils.quoteIdentifier(name));
}
@Override
@@ -294,7 +294,7 @@ public MySqlConnectionMetadata getMetadata() {
* MySQL does not have any way to query the isolation level of the current transaction, only inferred from
* past statements, so driver can not make sure the result is right.
*
- * See https://bugs.mysql.com/bug.php?id=53341
+ * See MySQL Bug 53341
*
* {@inheritDoc}
*/
@@ -467,7 +467,7 @@ static Mono init(
return connection;
}
- requireValidName(database, "database must not be empty and not contain backticks");
+ requireNonEmpty(database, "database must not be empty");
return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
.last()
@@ -476,7 +476,7 @@ static Mono init(
return Mono.just(conn);
}
- String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database);
+ String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);
return QueryFlow.executeVoid(client, sql)
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
index 4f67a085d..890914bde 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
@@ -16,12 +16,11 @@
package io.asyncer.r2dbc.mysql;
-
import org.jetbrains.annotations.Nullable;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require;
+import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
-import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireValidName;
/**
* Base class considers generic logic for {@link MySqlStatement} implementations.
@@ -42,8 +41,8 @@ public final MySqlStatement returnGeneratedValues(String... columns) {
this.generatedKeyName = LAST_INSERT_ID;
return this;
case 1:
- this.generatedKeyName = requireValidName(columns[0],
- "id name must not be empty and not contain backticks");
+ requireNonEmpty(columns[0], "id name must not be empty");
+ this.generatedKeyName = columns[0];
return this;
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
index f6d56cf6c..203c8ed5c 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
@@ -23,6 +23,7 @@
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
@@ -1230,7 +1231,7 @@ boolean cancelTasks() {
statements.add("BEGIN");
}
- final String doneSql = String.format("SAVEPOINT `%s`", name);
+ final String doneSql = "SAVEPOINT " + StringUtils.quoteIdentifier(name);
tasks |= CREATE_SAVEPOINT;
statements.add(doneSql);
return false;
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java
index 9aceb8570..246deac37 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/AssertUtils.java
@@ -16,7 +16,6 @@
package io.asyncer.r2dbc.mysql.internal.util;
-
import org.jetbrains.annotations.Nullable;
/**
@@ -57,20 +56,17 @@ public static void require(boolean condition, String message) {
}
/**
- * Checks that a specified {@link String} is not {@code null} or empty or backticks included and throws a
- * customized {@link IllegalArgumentException} if it is.
+ * Checks that a {@link String} is neither {@code null} nor empty, and throws a customized
+ * {@link IllegalArgumentException} if it is.
*
- * @param name the {@link String} to check for nullity or empty or backticks included.
+ * @param s the string to check for empty.
* @param message the detail message to be used by thrown {@link IllegalArgumentException}.
- * @return {@code name} if not {@code null} or empty or backticks included.
- * @throws IllegalArgumentException if {@code name} is {@code null} or empty or backticks included.
+ * @throws IllegalArgumentException if {@code s} is {@code null} or empty.
*/
- public static String requireValidName(@Nullable String name, String message) {
- if (name == null || name.isEmpty() || name.indexOf('`') >= 0) {
+ public static void requireNonEmpty(@Nullable String s, String message) {
+ if (s == null || s.isEmpty()) {
throw new IllegalArgumentException(message);
}
-
- return name;
}
private AssertUtils() { }
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
new file mode 100644
index 000000000..2ccbb55a5
--- /dev/null
+++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql.internal.util;
+
+import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
+
+/**
+ * A utility for processing {@link String} in MySQL/MariaDB.
+ */
+public final class StringUtils {
+
+ private static final char QUOTE = '`';
+
+ public static String quoteIdentifier(String identifier) {
+ requireNonEmpty(identifier, "identifier must not be empty");
+
+ int index = identifier.indexOf(QUOTE);
+
+ if (index == -1) {
+ return QUOTE + identifier + QUOTE;
+ }
+
+ int len = identifier.length();
+ StringBuilder builder = new StringBuilder(len + 10).append(QUOTE);
+ int fromIndex = 0;
+
+ while (index != -1) {
+ builder.append(identifier, fromIndex, index)
+ .append(QUOTE)
+ .append(QUOTE);
+ fromIndex = index + 1;
+ index = identifier.indexOf(QUOTE, fromIndex);
+ }
+
+ if (fromIndex < len) {
+ builder.append(identifier, fromIndex, len);
+ }
+
+ return builder.append(QUOTE).toString();
+ }
+
+ private StringUtils() {
+ }
+}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
index 5063c29f3..075604031 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
@@ -17,6 +17,8 @@
package io.asyncer.r2dbc.mysql;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -189,48 +191,50 @@ void autoCommitStatusIsRestoredAfterTransaction() {
.doOnSuccess(ignored -> assertThat(connection.isAutoCommit()).isTrue()));
}
- @Test
- void createSavepointAndRollbackToSavepoint() {
+ @ParameterizedTest
+ @ValueSource(strings = { "test", "save`point" })
+ void createSavepointAndRollbackToSavepoint(String savepoint) {
complete(connection -> Mono.from(connection.createStatement(
- "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
- .flatMap(IntegrationTestSupport::extractRowsUpdated)
- .then(connection.beginTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
- .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
- .execute()))
- .flatMap(IntegrationTestSupport::extractRowsUpdated)
- .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')")
- .execute()))
- .flatMap(IntegrationTestSupport::extractRowsUpdated)
- .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
- .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
- .doOnSuccess(count -> assertThat(count).isEqualTo(2))
- .then(connection.createSavepoint("test"))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
- .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
- .execute()))
- .flatMap(IntegrationTestSupport::extractRowsUpdated)
- .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')")
- .execute()))
- .flatMap(IntegrationTestSupport::extractRowsUpdated)
- .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
- .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
- .doOnSuccess(count -> assertThat(count).isEqualTo(4))
- .then(connection.rollbackTransactionToSavepoint("test"))
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
- .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
- .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
- .doOnSuccess(count -> assertThat(count).isEqualTo(2))
- .then(connection.rollbackTransaction())
- .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
- .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
- .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
- .doOnSuccess(count -> assertThat(count).isEqualTo(0))
+ "CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .then(connection.beginTransaction())
+ .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
+ .execute()))
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')")
+ .execute()))
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
+ .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
+ .doOnSuccess(count -> assertThat(count).isEqualTo(2))
+ .then(connection.createSavepoint(savepoint))
+ .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
+ .execute()))
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')")
+ .execute()))
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
+ .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
+ .doOnSuccess(count -> assertThat(count).isEqualTo(4))
+ .then(connection.rollbackTransactionToSavepoint(savepoint))
+ .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
+ .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
+ .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
+ .doOnSuccess(count -> assertThat(count).isEqualTo(2))
+ .then(connection.rollbackTransaction())
+ .doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
+ .then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
+ .flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
+ .doOnSuccess(count -> assertThat(count).isEqualTo(0))
);
}
- @Test
- void createSavepointAndRollbackEntireTransaction() {
+ @ParameterizedTest
+ @ValueSource(strings = { "test", "save`point" })
+ void createSavepointAndRollbackEntireTransaction(String savepoint) {
complete(connection -> Mono.from(connection.createStatement(
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
@@ -245,7 +249,7 @@ void createSavepointAndRollbackEntireTransaction() {
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
- .then(connection.createSavepoint("test"))
+ .then(connection.createSavepoint(savepoint))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java
index ef39693b6..b82f1154d 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java
@@ -98,9 +98,6 @@ void badCreateSavepoint() {
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.createSavepoint(""));
- asserted.isThrownBy(() -> noPrepare.createSavepoint("`"));
- asserted.isThrownBy(() -> noPrepare.createSavepoint("name`"));
- asserted.isThrownBy(() -> noPrepare.createSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.createSavepoint(null));
}
@@ -110,9 +107,6 @@ void badReleaseSavepoint() {
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.releaseSavepoint(""));
- asserted.isThrownBy(() -> noPrepare.releaseSavepoint("`"));
- asserted.isThrownBy(() -> noPrepare.releaseSavepoint("name`"));
- asserted.isThrownBy(() -> noPrepare.releaseSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.releaseSavepoint(null));
}
@@ -122,9 +116,6 @@ void badRollbackTransactionToSavepoint() {
ThrowableTypeAssert> asserted = assertThatIllegalArgumentException();
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(""));
- asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("`"));
- asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("name`"));
- asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(null));
}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
index 0de394fba..65300af67 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
@@ -156,6 +156,8 @@ default void returnGeneratedValues() {
assertEquals(statement.generatedKeyName, "LAST_INSERT_ID");
statement.returnGeneratedValues("generated");
assertEquals(statement.generatedKeyName, "generated");
+ statement.returnGeneratedValues("generate`d");
+ assertEquals(statement.generatedKeyName, "generate`d");
}
@SuppressWarnings("ConstantConditions")
@@ -166,8 +168,6 @@ default void badReturnGeneratedValues() {
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String) null));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String[]) null));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues(""));
- assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("`generating`"));
- assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("generating`"));
assertThrows(IllegalArgumentException.class, () ->
statement.returnGeneratedValues("generated", "names"));
}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java
new file mode 100644
index 000000000..e8be1f442
--- /dev/null
+++ b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/StringUtilsTest.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql.internal.util;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.NullAndEmptySource;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+
+/**
+ * Unit tests for {@link StringUtils}.
+ */
+class StringUtilsTest {
+
+ @ParameterizedTest
+ @ValueSource(strings = {
+ " ",
+ "`",
+ "``",
+ "bad`",
+ "`name",
+ "bad`name",
+ "`bad`name`",
+ "μ's",
+ " Reading books can be a \"great way\" to help you achieve more 'success' in `life` ",
+ "b%!@#()ar`\nfr-=321d`na``me",
+ "`b%!@#()ar`\nfr-=321d`na``me",
+ "`b%!@#()ar`\nfr-=321d`na``me``",
+ })
+ void quoteIdentifier(String name) {
+ assertThat(StringUtils.quoteIdentifier(name)).isEqualTo('`' + name.replaceAll("`", "``") + '`');
+ }
+
+ @ParameterizedTest
+ @NullAndEmptySource
+ void badQuoteIdentifier(String name) {
+ assertThatIllegalArgumentException().isThrownBy(() -> StringUtils.quoteIdentifier(name));
+ }
+}