Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FMWK-276 Fix parameterized queries in PreparedStatement #50

Merged
merged 1 commit into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/main/java/com/aerospike/jdbc/AerospikeDatabaseMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.aerospike.jdbc.sql.ListRecordSet;
import com.aerospike.jdbc.sql.SimpleWrapper;
import com.aerospike.jdbc.util.AerospikeUtils;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.io.IOException;
import java.io.StringReader;
Expand All @@ -22,6 +24,7 @@
import java.sql.Statement;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -49,14 +52,15 @@ public class AerospikeDatabaseMetadata implements DatabaseMetaData, SimpleWrappe
private static final String NEW_LINE = System.lineSeparator();

private final String url;
private final AerospikeConnection connection;
private final Connection connection;
private final String dbBuild;
private final String dbEdition;
private final List<String> catalogs;
private final Map<String, Collection<String>> tables;
private final Map<String, Collection<AerospikeSecondaryIndex>> catalogIndexes;
private final Map<String, AerospikeSecondaryIndex> secondaryIndexes;
private final AerospikeSchemaBuilder schemaBuilder;
private final Cache<String, ResultSetMetaData> resultSetMetaDataCache;

public AerospikeDatabaseMetadata(String url, IAerospikeClient client, AerospikeConnection connection) {
logger.info("Init AerospikeDatabaseMetadata");
Expand Down Expand Up @@ -103,6 +107,7 @@ public AerospikeDatabaseMetadata(String url, IAerospikeClient client, AerospikeC
.collect(Collectors.toMap(AerospikeSecondaryIndex::toKey, Function.identity()));

schemaBuilder = new AerospikeSchemaBuilder(client, connection.getConfiguration().getDriverPolicy());
resultSetMetaDataCache = CacheBuilder.newBuilder().build();

dbBuild = join("N/A", ", ", builds);
dbEdition = join("Aerospike", ", ", editions);
Expand Down Expand Up @@ -1304,13 +1309,19 @@ private int ordinal(ResultSetMetaData md, String columnName) {
}

private ResultSetMetaData getMetadata(String namespace, String table) {
try (Statement statement = connection.createStatement()) {
String query = format("SELECT * FROM \"%s.%s\" LIMIT %d", namespace, table,
connection.getConfiguration().getDriverPolicy().getSchemaBuilderMaxRecords());
return statement.executeQuery(query).getMetaData();
} catch (SQLException e) {
logger.severe(() -> format("Exception in getMetadata, namespace: %s, table: %s", namespace, table));
throw new IllegalArgumentException(e);
final String key = format("%s.%s", namespace, table);
try {
return resultSetMetaDataCache.get(key, () -> {
try (Statement statement = connection.createStatement()) {
String query = format("SELECT * FROM \"%s.%s\" LIMIT 1", namespace, table);
return statement.executeQuery(query).getMetaData();
} catch (SQLException e) {
logger.severe(() -> format("Exception in getMetadata, namespace: %s, table: %s", namespace, table));
throw new IllegalArgumentException(e);
}
});
} catch (ExecutionException e) {
throw new IllegalArgumentException(e.getCause());
}
}

Expand Down
79 changes: 44 additions & 35 deletions src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.aerospike.client.Value;
import com.aerospike.jdbc.model.AerospikeQuery;
import com.aerospike.jdbc.model.DataColumn;
import com.aerospike.jdbc.model.QueryType;
import com.aerospike.jdbc.sql.AerospikeResultSetMetaData;
import com.aerospike.jdbc.sql.SimpleParameterMetaData;
import com.aerospike.jdbc.sql.type.ByteArrayBlob;
Expand All @@ -26,43 +27,40 @@

import static com.aerospike.jdbc.util.PreparedStatement.parseParameters;
import static java.lang.String.format;
import static java.util.Objects.isNull;

public class AerospikePreparedStatement extends AerospikeStatement implements PreparedStatement {

private static final Logger logger = Logger.getLogger(AerospikePreparedStatement.class.getName());

private final String sql;
private final AerospikeConnection connection;
private final Object[] parameterValues;
private final AerospikeQuery query;
private final String sqlStatement;
private final Object[] sqlParameters;

public AerospikePreparedStatement(IAerospikeClient client, AerospikeConnection connection, String sql) {
public AerospikePreparedStatement(IAerospikeClient client, AerospikeConnection connection, String sqlStatement) {
super(client, connection);
this.sql = sql;
this.connection = connection;
parameterValues = buildParameterValues(sql);
try {
query = parseQuery(sql);
} catch (SQLException e) {
throw new UnsupportedOperationException(e);
}
this.sqlStatement = sqlStatement;
sqlParameters = buildSqlParameters(sqlStatement);
logger.info(() -> format("statement: %s, params: %d", sqlStatement, sqlParameters.length));
}

private Object[] buildParameterValues(String sql) {
private Object[] buildSqlParameters(String sql) {
int params = parseParameters(sql, 0).getValue();
return new Object[params];
}

@Override
public ResultSet executeQuery() throws SQLException {
logger.info("AerospikePreparedStatement executeQuery");
return super.executeQuery(sql);
String preparedQueryString = prepareQueryString();
logger.info(() -> "executeQuery: " + preparedQueryString);
AerospikeQuery query = parseQuery(preparedQueryString);
runQuery(query);
return resultSet;
}

@Override
public int executeUpdate() throws SQLException {
logger.info("AerospikePreparedStatement executeUpdate");
return super.executeUpdate(sql);
executeQuery();
return updateCount;
}

@Override
Expand Down Expand Up @@ -116,7 +114,7 @@ public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException

@Override
public void setString(int parameterIndex, String x) throws SQLException {
setObject(parameterIndex, "\"" + x + "\"");
setObject(parameterIndex, format("\"%s\"", x));
}

@Override
Expand Down Expand Up @@ -149,6 +147,7 @@ public void setAsciiStream(int parameterIndex, InputStream x, int length) throws
*/
@Override
@Deprecated
@SuppressWarnings("java:S1133")
public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException {
throw new SQLFeatureNotSupportedException("setUnicodeStream is deprecated");
}
Expand All @@ -160,7 +159,7 @@ public void setBinaryStream(int parameterIndex, InputStream x, int length) throw

@Override
public void clearParameters() {
Arrays.fill(parameterValues, null);
Arrays.fill(sqlParameters, null);
}

@Override
Expand All @@ -170,28 +169,36 @@ public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQ

@Override
public void setObject(int parameterIndex, Object x) throws SQLException {
if (parameterIndex <= 0 || parameterIndex > parameterValues.length) {
throw new SQLException(parameterValues.length == 0 ?
"Current SQL statement does not have parameters" :
format("Wrong parameter index. Expected from %d till %d", 1, parameterValues.length));
if (parameterIndex <= 0 || parameterIndex > sqlParameters.length) {
throw new SQLDataException(sqlParameters.length == 0
? "Current SQL statement does not have parameters"
: format("The parameter index %d is out of range, number of parameters: %d",
parameterIndex, sqlParameters.length));
}
parameterValues[parameterIndex - 1] = x;
sqlParameters[parameterIndex - 1] = x;
}

@Override
public boolean execute() throws SQLException {
String preparedQuery = prepareQuery();
logger.info(preparedQuery);
return execute(preparedQuery);
}

private String prepareQuery() {
return format(this.sql.replace("?", "%s"), parameterValues);
String preparedQueryString = prepareQueryString();
logger.info(() -> "execute: " + preparedQueryString);
AerospikeQuery query = parseQuery(preparedQueryString);
runQuery(query);
return query.getQueryType() == QueryType.SELECT;
}

private String prepareQueryString() {
String preparedQueryString = sqlStatement;
for (Object value : sqlParameters) {
String replacement = isNull(value) ? "?" : value.toString();
preparedQueryString = preparedQueryString.replaceFirst("\\?", replacement);
}
return preparedQueryString;
}

@Override
public void addBatch() throws SQLException {
addBatch(sql);
addBatch(prepareQueryString());
}

@Override
Expand Down Expand Up @@ -221,6 +228,7 @@ public void setArray(int parameterIndex, Array x) throws SQLException {

@Override
public ResultSetMetaData getMetaData() throws SQLException {
AerospikeQuery query = parseQuery(prepareQueryString());
List<DataColumn> columns = ((AerospikeDatabaseMetadata) connection.getMetaData())
.getSchemaBuilder()
.getSchema(query.getSchemaTable());
Expand Down Expand Up @@ -254,6 +262,7 @@ public void setURL(int parameterIndex, URL url) throws SQLException {

@Override
public ParameterMetaData getParameterMetaData() throws SQLException {
AerospikeQuery query = parseQuery(prepareQueryString());
List<DataColumn> columns = ((AerospikeDatabaseMetadata) connection.getMetaData())
.getSchemaBuilder()
.getSchema(query.getSchemaTable());
Expand Down Expand Up @@ -297,9 +306,9 @@ public void setClob(int parameterIndex, Reader reader, long length) throws SQLEx
@Override
public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException {
byte[] bytes = new byte[(int) length];
DataInputStream dis = new DataInputStream(inputStream);
DataInputStream dataInputStream = new DataInputStream(inputStream);
try {
dis.readFully(bytes);
dataInputStream.readFully(bytes);
if (inputStream.read() != -1) {
throw new SQLException(format("Source contains more bytes than required %d", length));
}
Expand Down
21 changes: 14 additions & 7 deletions src/main/java/com/aerospike/jdbc/AerospikeStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.aerospike.client.IAerospikeClient;
import com.aerospike.jdbc.model.AerospikeQuery;
import com.aerospike.jdbc.model.Pair;
import com.aerospike.jdbc.model.QueryType;
import com.aerospike.jdbc.query.QueryPerformer;
import com.aerospike.jdbc.sql.SimpleWrapper;
import com.aerospike.jdbc.util.AuxStatementParser;
Expand All @@ -29,12 +30,14 @@ public class AerospikeStatement implements Statement, SimpleWrapper {
private static final String AUTO_GENERATED_KEYS_NOT_SUPPORTED_MESSAGE = "Auto-generated keys are not supported";

protected final IAerospikeClient client;
private final Connection connection;
protected final AerospikeConnection connection;

protected String schema;
protected ResultSet resultSet;
protected int updateCount;

private int maxRows = Integer.MAX_VALUE;
private int queryTimeout;
private ResultSet resultSet;
private int updateCount;

public AerospikeStatement(IAerospikeClient client, AerospikeConnection connection) {
this.client = client;
Expand All @@ -50,12 +53,14 @@ public AerospikeStatement(IAerospikeClient client, AerospikeConnection connectio
public ResultSet executeQuery(String sql) throws SQLException {
logger.info(() -> "executeQuery: " + sql);
AerospikeQuery query = parseQuery(sql);
runQuery(query);
return resultSet;
}

protected void runQuery(AerospikeQuery query) {
Pair<ResultSet, Integer> result = QueryPerformer.executeQuery(client, this, query);
resultSet = result.getLeft();
updateCount = result.getRight();

return resultSet;
}

protected AerospikeQuery parseQuery(String sql) throws SQLException {
Expand Down Expand Up @@ -140,8 +145,10 @@ public void setCursorName(String name) throws SQLException {

@Override
public boolean execute(String sql) throws SQLException {
resultSet = executeQuery(sql);
return true;
logger.info(() -> "execute: " + sql);
AerospikeQuery query = parseQuery(sql);
runQuery(query);
return query.getQueryType() == QueryType.SELECT;
}

@Override
Expand Down
10 changes: 6 additions & 4 deletions src/main/java/com/aerospike/jdbc/model/DriverPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

public class DriverPolicy {

private static final int DEFAULT_CAPACITY = 256;
private static final int DEFAULT_TIMEOUT_MS = 1000;
private static final int DEFAULT_RECORD_SET_QUEUE_CAPACITY = 256;
private static final int DEFAULT_RECORD_SET_TIMEOUT_MS = 1000;
private static final int DEFAULT_METADATA_CACHE_TTL_SECONDS = 3600;
private static final int DEFAULT_SCHEMA_BUILDER_MAX_RECORDS = 1000;

Expand All @@ -15,8 +15,10 @@ public class DriverPolicy {
private final int schemaBuilderMaxRecords;

public DriverPolicy(Properties properties) {
recordSetQueueCapacity = parseInt(properties.getProperty("recordSetQueueCapacity"), DEFAULT_CAPACITY);
recordSetTimeoutMs = parseInt(properties.getProperty("recordSetTimeoutMs"), DEFAULT_TIMEOUT_MS);
recordSetQueueCapacity = parseInt(properties.getProperty("recordSetQueueCapacity"),
DEFAULT_RECORD_SET_QUEUE_CAPACITY);
recordSetTimeoutMs = parseInt(properties.getProperty("recordSetTimeoutMs"),
DEFAULT_RECORD_SET_TIMEOUT_MS);
metadataCacheTtlSeconds = parseInt(properties.getProperty("metadataCacheTtlSeconds"),
DEFAULT_METADATA_CACHE_TTL_SECONDS);
schemaBuilderMaxRecords = parseInt(properties.getProperty("schemaBuilderMaxRecords"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ default InputStream getAsciiStream(int columnIndex) throws SQLException {
*/
@Override
@Deprecated
@SuppressWarnings("java:S1133")
default InputStream getUnicodeStream(int columnIndex) throws SQLException {
return getUnicodeStream(getColumnLabel(columnIndex));
}
Expand Down
10 changes: 4 additions & 6 deletions src/test/java/com/aerospike/jdbc/DatabaseMetadataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Objects;

import static com.aerospike.jdbc.util.TestUtil.closeQuietly;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
Expand All @@ -25,10 +26,8 @@ public void setUp() throws SQLException {
Objects.requireNonNull(connection, "connection is null");
PreparedStatement statement = null;
int count;
String query = String.format(
"insert into %s (bin1, int1, str1, bool1) values (11100, 1, \"bar\", true)",
tableName
);
String query = format("insert into %s (bin1, int1, str1, bool1) values (11100, 1, \"bar\", true)",
tableName);
try {
statement = connection.prepareStatement(query);
count = statement.executeUpdate();
Expand All @@ -43,7 +42,7 @@ public void tearDown() throws SQLException {
Objects.requireNonNull(connection, "connection is null");
PreparedStatement statement = null;
ResultSet resultSet = null;
String query = String.format("delete from %s", tableName);
String query = format("delete from %s", tableName);
try {
statement = connection.prepareStatement(query);
resultSet = statement.executeQuery();
Expand All @@ -57,7 +56,6 @@ public void tearDown() throws SQLException {
@Test
public void testGetTables() throws SQLException {
DatabaseMetaData databaseMetaData = connection.getMetaData();

ResultSet rs = databaseMetaData.getTables(namespace, namespace, tableName, null);

assertTrue(rs.next());
Expand Down
Loading