Skip to content

Commit

Permalink
Add quote identifier support
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Jan 15, 2024
1 parent 72f824c commit e01335f
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 77 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
24 changes: 12 additions & 12 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
Expand All @@ -47,8 +48,8 @@
import java.util.function.Function;
import java.util.function.Predicate;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireValidName;

/**
* An implementation of {@link Connection} for connecting to the MySQL database.
Expand Down Expand Up @@ -222,7 +223,8 @@ public MySqlBatch createBatch() {

@Override
public Mono<Void> createSavepoint(String name) {
requireValidName(name, "Savepoint name must not be empty and not contain backticks");
requireNonEmpty(name, "Savepoint name must not be empty");

return QueryFlow.createSavepoint(client, this, name, batchSupported);
}

Expand Down Expand Up @@ -266,23 +268,21 @@ public Mono<Void> preRelease() {

@Override
public Mono<Void> releaseSavepoint(String name) {
requireValidName(name, "Savepoint name must not be empty and not contain backticks");
requireNonEmpty(name, "Savepoint name must not be empty");

return QueryFlow.executeVoid(client, String.format("RELEASE SAVEPOINT `%s`", name));
return QueryFlow.executeVoid(client, "RELEASE SAVEPOINT " + StringUtils.quoteIdentifier(name));
}

@Override
public Mono<Void> rollbackTransaction() {
return Mono.defer(() -> {
return QueryFlow.doneTransaction(client, this, false, batchSupported);
});
return Mono.defer(() -> QueryFlow.doneTransaction(client, this, false, batchSupported));
}

@Override
public Mono<Void> rollbackTransactionToSavepoint(String name) {
requireValidName(name, "Savepoint name must not be empty and not contain backticks");
requireNonEmpty(name, "Savepoint name must not be empty");

return QueryFlow.executeVoid(client, String.format("ROLLBACK TO SAVEPOINT `%s`", name));
return QueryFlow.executeVoid(client, "ROLLBACK TO SAVEPOINT " + StringUtils.quoteIdentifier(name));
}

@Override
Expand All @@ -294,7 +294,7 @@ public MySqlConnectionMetadata getMetadata() {
* MySQL does not have any way to query the isolation level of the current transaction, only inferred from
* past statements, so driver can not make sure the result is right.
* <p>
* See https://bugs.mysql.com/bug.php?id=53341
* See <a href="https://bugs.mysql.com/bug.php?id=53341">MySQL Bug 53341</a>
* <p>
* {@inheritDoc}
*/
Expand Down Expand Up @@ -467,7 +467,7 @@ static Mono<MySqlConnection> init(
return connection;
}

requireValidName(database, "database must not be empty and not contain backticks");
requireNonEmpty(database, "database must not be empty");

return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
.last()
Expand All @@ -476,7 +476,7 @@ static Mono<MySqlConnection> init(
return Mono.just(conn);
}

String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database);
String sql = "CREATE DATABASE IF NOT EXISTS " + StringUtils.quoteIdentifier(database);

return QueryFlow.executeVoid(client, sql)
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

package io.asyncer.r2dbc.mysql;


import org.jetbrains.annotations.Nullable;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireValidName;

/**
* Base class considers generic logic for {@link MySqlStatement} implementations.
Expand All @@ -42,8 +41,8 @@ public final MySqlStatement returnGeneratedValues(String... columns) {
this.generatedKeyName = LAST_INSERT_ID;
return this;
case 1:
this.generatedKeyName = requireValidName(columns[0],
"id name must not be empty and not contain backticks");
requireNonEmpty(columns[0], "id name must not be empty");
this.generatedKeyName = columns[0];
return this;
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
Expand Down Expand Up @@ -1230,7 +1231,7 @@ boolean cancelTasks() {
statements.add("BEGIN");
}

final String doneSql = String.format("SAVEPOINT `%s`", name);
final String doneSql = "SAVEPOINT " + StringUtils.quoteIdentifier(name);
tasks |= CREATE_SAVEPOINT;
statements.add(doneSql);
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package io.asyncer.r2dbc.mysql.internal.util;


import org.jetbrains.annotations.Nullable;

/**
Expand Down Expand Up @@ -57,20 +56,17 @@ public static void require(boolean condition, String message) {
}

/**
* Checks that a specified {@link String} is not {@code null} or empty or backticks included and throws a
* customized {@link IllegalArgumentException} if it is.
* Checks that a {@link String} is neither {@code null} nor empty, and throws a customized
* {@link IllegalArgumentException} if it is.
*
* @param name the {@link String} to check for nullity or empty or backticks included.
* @param s the string to check for empty.
* @param message the detail message to be used by thrown {@link IllegalArgumentException}.
* @return {@code name} if not {@code null} or empty or backticks included.
* @throws IllegalArgumentException if {@code name} is {@code null} or empty or backticks included.
* @throws IllegalArgumentException if {@code s} is {@code null} or empty.
*/
public static String requireValidName(@Nullable String name, String message) {
if (name == null || name.isEmpty() || name.indexOf('`') >= 0) {
public static void requireNonEmpty(@Nullable String s, String message) {
if (s == null || s.isEmpty()) {
throw new IllegalArgumentException(message);
}

return name;
}

private AssertUtils() { }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2024 asyncer.io projects
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.asyncer.r2dbc.mysql.internal.util;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;

/**
* A utility for processing {@link String} in MySQL/MariaDB.
*/
public final class StringUtils {

private static final char QUOTE = '`';

public static String quoteIdentifier(String identifier) {
requireNonEmpty(identifier, "identifier must not be empty");

int index = identifier.indexOf(QUOTE);

if (index == -1) {
return QUOTE + identifier + QUOTE;
}

int len = identifier.length();
StringBuilder builder = new StringBuilder(len + 10).append(QUOTE);
int fromIndex = 0;

while (index != -1) {
builder.append(identifier, fromIndex, index)
.append(QUOTE)
.append(QUOTE);
fromIndex = index + 1;
index = identifier.indexOf(QUOTE, fromIndex);
}

if (fromIndex < len) {
builder.append(identifier, fromIndex, len);
}

return builder.append(QUOTE).toString();
}

private StringUtils() {
}
}
82 changes: 43 additions & 39 deletions src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.asyncer.r2dbc.mysql;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -189,48 +191,50 @@ void autoCommitStatusIsRestoredAfterTransaction() {
.doOnSuccess(ignored -> assertThat(connection.isAutoCommit()).isTrue()));
}

@Test
void createSavepointAndRollbackToSavepoint() {
@ParameterizedTest
@ValueSource(strings = { "test", "save`point" })
void createSavepointAndRollbackToSavepoint(String savepoint) {
complete(connection -> Mono.from(connection.createStatement(
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(connection.beginTransaction())
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.createSavepoint("test"))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(4))
.then(connection.rollbackTransactionToSavepoint("test"))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.rollbackTransaction())
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(0))
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(connection.beginTransaction())
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (1, 'test1')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (2, 'test2')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.createSavepoint(savepoint))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (4, 'test4')")
.execute()))
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(4))
.then(connection.rollbackTransactionToSavepoint(savepoint))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.rollbackTransaction())
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isFalse())
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(0))
);
}

@Test
void createSavepointAndRollbackEntireTransaction() {
@ParameterizedTest
@ValueSource(strings = { "test", "save`point" })
void createSavepointAndRollbackEntireTransaction(String savepoint) {
complete(connection -> Mono.from(connection.createStatement(
"CREATE TEMPORARY TABLE test (id INT NOT NULL PRIMARY KEY, name VARCHAR(50))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
Expand All @@ -245,7 +249,7 @@ void createSavepointAndRollbackEntireTransaction() {
.then(Mono.from(connection.createStatement("SELECT COUNT(*) FROM test").execute()))
.flatMap(result -> Mono.from(result.map((row, metadata) -> row.get(0, Long.class))))
.doOnSuccess(count -> assertThat(count).isEqualTo(2))
.then(connection.createSavepoint("test"))
.then(connection.createSavepoint(savepoint))
.doOnSuccess(ignored -> assertThat(connection.isInTransaction()).isTrue())
.then(Mono.from(connection.createStatement("INSERT INTO test VALUES (3, 'test3')")
.execute()))
Expand Down
9 changes: 0 additions & 9 deletions src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ void badCreateSavepoint() {
ThrowableTypeAssert<?> asserted = assertThatIllegalArgumentException();

asserted.isThrownBy(() -> noPrepare.createSavepoint(""));
asserted.isThrownBy(() -> noPrepare.createSavepoint("`"));
asserted.isThrownBy(() -> noPrepare.createSavepoint("name`"));
asserted.isThrownBy(() -> noPrepare.createSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.createSavepoint(null));
}

Expand All @@ -110,9 +107,6 @@ void badReleaseSavepoint() {
ThrowableTypeAssert<?> asserted = assertThatIllegalArgumentException();

asserted.isThrownBy(() -> noPrepare.releaseSavepoint(""));
asserted.isThrownBy(() -> noPrepare.releaseSavepoint("`"));
asserted.isThrownBy(() -> noPrepare.releaseSavepoint("name`"));
asserted.isThrownBy(() -> noPrepare.releaseSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.releaseSavepoint(null));
}

Expand All @@ -122,9 +116,6 @@ void badRollbackTransactionToSavepoint() {
ThrowableTypeAssert<?> asserted = assertThatIllegalArgumentException();

asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(""));
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("`"));
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("name`"));
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint("nam`e"));
asserted.isThrownBy(() -> noPrepare.rollbackTransactionToSavepoint(null));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ default void returnGeneratedValues() {
assertEquals(statement.generatedKeyName, "LAST_INSERT_ID");
statement.returnGeneratedValues("generated");
assertEquals(statement.generatedKeyName, "generated");
statement.returnGeneratedValues("generate`d");
assertEquals(statement.generatedKeyName, "generate`d");
}

@SuppressWarnings("ConstantConditions")
Expand All @@ -166,8 +168,6 @@ default void badReturnGeneratedValues() {
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String) null));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String[]) null));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues(""));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("`generating`"));
assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("generating`"));
assertThrows(IllegalArgumentException.class, () ->
statement.returnGeneratedValues("generated", "names"));
}
Expand Down
Loading

0 comments on commit e01335f

Please sign in to comment.