Skip to content

Commit

Permalink
Merge pull request #28132 from ctailor2
Browse files Browse the repository at this point in the history
* pr/28132:
  Polish "Allow batch update to take a KeyHolder"
  Allow batch update to take a KeyHolder

Closes gh-28132
  • Loading branch information
snicoll committed Sep 15, 2023
2 parents 056de7e + c21a9b9 commit f628c60
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,29 @@ <T> List<T> queryForList(String sql, Object[] args, int[] argTypes, Class<T> ele
*/
int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException;

/**
* Issue multiple update statements on a single PreparedStatement,
* using batch updates and a BatchPreparedStatementSetter to set values.
* Generated keys will be put into the given KeyHolder.
* <p>Note that the given PreparedStatementCreator has to create a statement
* with activated extraction of generated keys (a JDBC 3.0 feature). This can
* either be done directly or through using a PreparedStatementCreatorFactory.
* <p>Will fall back to separate updates on a single PreparedStatement
* if the JDBC driver does not support batch updates.
* @param psc a callback that creates a PreparedStatement given a Connection
* @param pss object to set parameters on the PreparedStatement
* created by this method
* @param generatedKeyHolder a KeyHolder that will hold the generated keys
* @return an array of the number of rows affected by each statement
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @since 6.1
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss,
KeyHolder generatedKeyHolder) throws DataAccessException;

/**
* Execute a batch using the supplied SQL statement with the batch of supplied arguments.
* @param sql the SQL statement to execute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,21 +996,10 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK

return updateCount(execute(psc, ps -> {
int rows = ps.executeUpdate();
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
generatedKeys.clear();
ResultSet keys = ps.getGeneratedKeys();
if (keys != null) {
try {
RowMapperResultSetExtractor<Map<String, Object>> rse =
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1);
generatedKeys.addAll(result(rse.extractData(keys)));
}
finally {
JdbcUtils.closeResultSet(keys);
}
}
generatedKeyHolder.getKeyList().clear();
storeGeneratedKeys(generatedKeyHolder, ps, 1);
if (logger.isTraceEnabled()) {
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys");
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys");
}
return rows;
}, true));
Expand All @@ -1031,6 +1020,16 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti
return update(sql, newArgPreparedStatementSetter(args));
}

@Override
public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss,
final KeyHolder generatedKeyHolder) throws DataAccessException {

int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder));

Assert.state(result != null, "No result array");
return result;
}

@Override
public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException {
if (logger.isDebugEnabled()) {
Expand All @@ -1041,43 +1040,7 @@ public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) thr
return new int[0];
}

int[] result = execute(sql, (PreparedStatementCallback<int[]>) ps -> {
try {
InterruptibleBatchPreparedStatementSetter ipss =
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
ps.addBatch();
}
return ps.executeBatch();
}
else {
List<Integer> rowsAffected = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
rowsAffected.add(ps.executeUpdate());
}
int[] rowsAffectedArray = new int[rowsAffected.size()];
for (int i = 0; i < rowsAffectedArray.length; i++) {
rowsAffectedArray[i] = rowsAffected.get(i);
}
return rowsAffectedArray;
}
}
finally {
if (pss instanceof ParameterDisposer parameterDisposer) {
parameterDisposer.cleanupParameters();
}
}
});

int[] result = execute(sql, getPreparedStatementCallback(pss, null));
Assert.state(result != null, "No result array");
return result;
}
Expand Down Expand Up @@ -1604,6 +1567,71 @@ private static int updateCount(@Nullable Integer result) {
return result;
}

private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException {
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
ResultSet keys = ps.getGeneratedKeys();
if (keys != null) {
try {
RowMapperResultSetExtractor<Map<String, Object>> rse =
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected);
generatedKeys.addAll(result(rse.extractData(keys)));
}
finally {
JdbcUtils.closeResultSet(keys);
}
}
}

private PreparedStatementCallback<int[]> getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {
return ps -> {
try {
int batchSize = pss.getBatchSize();
InterruptibleBatchPreparedStatementSetter ipss =
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
if (generatedKeyHolder != null) {
generatedKeyHolder.getKeyList().clear();
}
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
ps.addBatch();
}
int[] results = ps.executeBatch();
if (generatedKeyHolder != null) {
storeGeneratedKeys(generatedKeyHolder, ps, batchSize);
}
return results;
}
else {
List<Integer> rowsAffected = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
rowsAffected.add(ps.executeUpdate());
if (generatedKeyHolder != null) {
storeGeneratedKeys(generatedKeyHolder, ps, 1);
}
}
int[] rowsAffectedArray = new int[rowsAffected.size()];
for (int i = 0; i < rowsAffectedArray.length; i++) {
rowsAffectedArray[i] = rowsAffected.get(i);
}
return rowsAffectedArray;
}
}
finally {
if (pss instanceof ParameterDisposer parameterDisposer) {
parameterDisposer.cleanupParameters();
}
}
};
}


/**
* Invocation handler that suppresses close calls on JDBC Connections.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,37 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs);

/**
* Execute a batch using the supplied SQL statement with the batch of supplied
* arguments, returning generated keys.
* @param sql the SQL statement to execute
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
* arguments for the query
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
* @return an array containing the numbers of rows affected by each update in the batch
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @since 6.1
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder);

/**
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
* returning generated keys.
* @param sql the SQL statement to execute
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
* arguments for the query
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
* @param keyColumnNames names of the columns that will have keys generated for them
* @return an array containing the numbers of rows affected by each update in the batch
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @since 6.1
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder,
String[] keyColumnNames);
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,44 @@ public int getBatchSize() {
});
}

@Override
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) {
return batchUpdate(sql, batchArgs, generatedKeyHolder, null);
}

@Override
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder,
@Nullable String[] keyColumnNames) {

if (batchArgs.length == 0) {
return new int[0];
}

ParsedSql parsedSql = getParsedSql(sql);
SqlParameterSource paramSource = batchArgs[0];
PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource);
if (keyColumnNames != null) {
pscf.setGeneratedKeysColumnNames(keyColumnNames);
}
else {
pscf.setReturnGeneratedKeys(true);
}
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
return getJdbcOperations().batchUpdate(psc, new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null);
pscf.newPreparedStatementSetter(values).setValues(ps);
}

@Override
public int getBatchSize() {
return batchArgs.length;
}
}, generatedKeyHolder);
}


/**
* Build a {@link PreparedStatementCreator} based on the given SQL and named parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter;
import org.springframework.jdbc.datasource.ConnectionProxy;
import org.springframework.jdbc.datasource.SingleConnectionDataSource;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator;
import org.springframework.util.LinkedCaseInsensitiveMap;
Expand Down Expand Up @@ -1104,6 +1106,83 @@ public void testEquallyNamedColumn() throws SQLException {
assertThat(map.get("x")).isEqualTo("first value");
}

@Test
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException {
final int[] rowsAffected = new int[] {1, 2};
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
given(databaseMetaData.supportsBatchUpdates()).willReturn(true);
given(this.connection.getMetaData()).willReturn(databaseMetaData);
ResultSet generatedKeysResultSet = mock(ResultSet.class);
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
given(rsmd.getColumnCount()).willReturn(1);
given(rsmd.getColumnLabel(1)).willReturn("someId");
given(generatedKeysResultSet.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet.getObject(1)).willReturn(123, 456);
given(generatedKeysResultSet.next()).willReturn(true, true, false);
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet);

int[] values = new int[]{100, 200};
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
ps.setObject(i, values[i]);
}

@Override
public int getBatchSize() {
return 2;
}
};

KeyHolder keyHolder = new GeneratedKeyHolder();
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);

assertThat(keyHolder.getKeyList()).containsExactly(
Collections.singletonMap("someId", 123),
Collections.singletonMap("someId", 456));
}

@Test
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException {
final int[] rowsAffected = new int[] {1, 2};
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
given(databaseMetaData.supportsBatchUpdates()).willReturn(false);
given(this.connection.getMetaData()).willReturn(databaseMetaData);
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
given(rsmd.getColumnCount()).willReturn(1);
given(rsmd.getColumnLabel(1)).willReturn("someId");
ResultSet generatedKeysResultSet1 = mock(ResultSet.class);
given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet1.getObject(1)).willReturn(123);
given(generatedKeysResultSet1.next()).willReturn(true, false);
ResultSet generatedKeysResultSet2 = mock(ResultSet.class);
given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet2.getObject(1)).willReturn(456);
given(generatedKeysResultSet2.next()).willReturn(true, false);
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2);

int[] values = new int[]{100, 200};
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
ps.setObject(i, values[i]);
}

@Override
public int getBatchSize() {
return 2;
}
};

KeyHolder keyHolder = new GeneratedKeyHolder();
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);

assertThat(keyHolder.getKeyList()).containsExactly(
Collections.singletonMap("someId", 123),
Collections.singletonMap("someId", 456));
}

private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
DatabaseMetaData databaseMetaData = mock();
Expand Down

0 comments on commit f628c60

Please sign in to comment.