Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch update return generated keys #28132

Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,27 @@ <T> List<T> queryForList(String sql, Object[] args, int[] argTypes, Class<T> ele
*/
int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException;

/**
* Issue multiple update statements on a single PreparedStatement,
* using batch updates and a BatchPreparedStatementSetter to set values.
* Generated keys will be put into the given KeyHolder.
* <p>Note that the given PreparedStatementCreator has to create a statement
* with activated extraction of generated keys (a JDBC 3.0 feature). This can
* either be done directly or through using a PreparedStatementCreatorFactory.
* <p>Will fall back to separate updates on a single PreparedStatement
* if the JDBC driver does not support batch updates.
* @param psc a callback that creates a PreparedStatement given a Connection
* @param pss object to set parameters on the PreparedStatement
* created by this method
* @param generatedKeyHolder a KeyHolder that will hold the generated keys
* @return an array of the number of rows affected by each statement
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss, KeyHolder generatedKeyHolder) throws DataAccessException;

/**
* Execute a batch using the supplied SQL statement with the batch of supplied arguments.
* @param sql the SQL statement to execute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,21 +990,10 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK

return updateCount(execute(psc, ps -> {
int rows = ps.executeUpdate();
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
generatedKeys.clear();
ResultSet keys = ps.getGeneratedKeys();
if (keys != null) {
try {
RowMapperResultSetExtractor<Map<String, Object>> rse =
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1);
generatedKeys.addAll(result(rse.extractData(keys)));
}
finally {
JdbcUtils.closeResultSet(keys);
}
}
generatedKeyHolder.getKeyList().clear();
storeGeneratedKeys(generatedKeyHolder, ps, 1);
if (logger.isTraceEnabled()) {
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys");
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys");
}
return rows;
}, true));
Expand All @@ -1025,50 +1014,21 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti
return update(sql, newArgPreparedStatementSetter(args));
}

@Override
public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException {
int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder));

Assert.state(result != null, "No result array");
return result;
}

@Override
public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException {
if (logger.isDebugEnabled()) {
logger.debug("Executing SQL batch update [" + sql + "]");
}

int[] result = execute(sql, (PreparedStatementCallback<int[]>) ps -> {
try {
int batchSize = pss.getBatchSize();
InterruptibleBatchPreparedStatementSetter ipss =
(pss instanceof InterruptibleBatchPreparedStatementSetter ?
(InterruptibleBatchPreparedStatementSetter) pss : null);
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
ps.addBatch();
}
return ps.executeBatch();
}
else {
List<Integer> rowsAffected = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
rowsAffected.add(ps.executeUpdate());
}
int[] rowsAffectedArray = new int[rowsAffected.size()];
for (int i = 0; i < rowsAffectedArray.length; i++) {
rowsAffectedArray[i] = rowsAffected.get(i);
}
return rowsAffectedArray;
}
}
finally {
if (pss instanceof ParameterDisposer) {
((ParameterDisposer) pss).cleanupParameters();
}
}
});
int[] result = execute(sql, getPreparedStatementCallback(pss, null));

Assert.state(result != null, "No result array");
return result;
Expand Down Expand Up @@ -1567,6 +1527,72 @@ private static int updateCount(@Nullable Integer result) {
return result;
}

private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException {
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
ResultSet keys = ps.getGeneratedKeys();
if (keys != null) {
try {
RowMapperResultSetExtractor<Map<String, Object>> rse =
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected);
generatedKeys.addAll(result(rse.extractData(keys)));
}
finally {
JdbcUtils.closeResultSet(keys);
}
}
}

private PreparedStatementCallback<int[]> getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest a more descriptive name:

Suggested change
private PreparedStatementCallback<int[]> getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {
private PreparedStatementCallback<int[]> getPreparedStatementCallbackForBatchUpdate(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {

(Of course, also adjust this where it's used.)

return ps -> {
try {
int batchSize = pss.getBatchSize();
InterruptibleBatchPreparedStatementSetter ipss =
(pss instanceof InterruptibleBatchPreparedStatementSetter ?
(InterruptibleBatchPreparedStatementSetter) pss : null);
if (generatedKeyHolder != null) {
generatedKeyHolder.getKeyList().clear();
}
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
ps.addBatch();
}
int[] results = ps.executeBatch();
if (generatedKeyHolder != null) {
storeGeneratedKeys(generatedKeyHolder, ps, batchSize);
}
return results;
}
else {
List<Integer> rowsAffected = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
pss.setValues(ps, i);
if (ipss != null && ipss.isBatchExhausted(i)) {
break;
}
rowsAffected.add(ps.executeUpdate());
if (generatedKeyHolder != null) {
storeGeneratedKeys(generatedKeyHolder, ps, 1);
}
}
int[] rowsAffectedArray = new int[rowsAffected.size()];
for (int i = 0; i < rowsAffectedArray.length; i++) {
rowsAffectedArray[i] = rowsAffected.get(i);
}
return rowsAffectedArray;
}
}
finally {
if (pss instanceof ParameterDisposer) {
((ParameterDisposer) pss).cleanupParameters();
}
}
};
}


/**
* Invocation handler that suppresses close calls on JDBC Connections.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,34 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs);

/**
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
* returning generated keys.
* @param sql the SQL statement to execute
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
* arguments for the query
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
* @return an array containing the numbers of rows affected by each update in the batch
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder);

/**
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
* returning generated keys.
* @param sql the SQL statement to execute
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
* arguments for the query
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
* @param keyColumnNames names of the columns that will have keys generated for them
* @return an array containing the numbers of rows affected by each update in the batch
* (may also contain special JDBC-defined negative values for affected rows such as
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
* @throws DataAccessException if there is any problem issuing the update
* @see org.springframework.jdbc.support.GeneratedKeyHolder
*/
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames);
}
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,44 @@ public int getBatchSize() {
});
}

@Override
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) {
return batchUpdate(sql, batchArgs, generatedKeyHolder, null);
}

@Override
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames) {
if (batchArgs.length == 0) {
return new int[0];
}

ParsedSql parsedSql = getParsedSql(sql);
SqlParameterSource paramSource = batchArgs[0];
PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource);
if (keyColumnNames != null) {
pscf.setGeneratedKeysColumnNames(keyColumnNames);
}
else {
pscf.setReturnGeneratedKeys(true);
}
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
return getJdbcOperations().batchUpdate(
psc,
new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null);
pscf.newPreparedStatementSetter(values).setValues(ps);
}
@Override
public int getBatchSize() {
return batchArgs.length;
}
},
generatedKeyHolder);
}


/**
* Build a {@link PreparedStatementCreator} based on the given SQL and named parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter;
import org.springframework.jdbc.datasource.ConnectionProxy;
import org.springframework.jdbc.datasource.SingleConnectionDataSource;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator;
import org.springframework.util.LinkedCaseInsensitiveMap;
Expand Down Expand Up @@ -1085,6 +1087,83 @@ public void testEquallyNamedColumn() throws SQLException {
assertThat(map.get("x")).isEqualTo("first value");
}

@Test
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException {
final int[] rowsAffected = new int[] {1, 2};
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
given(databaseMetaData.supportsBatchUpdates()).willReturn(true);
given(this.connection.getMetaData()).willReturn(databaseMetaData);
ResultSet generatedKeysResultSet = mock(ResultSet.class);
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
given(rsmd.getColumnCount()).willReturn(1);
given(rsmd.getColumnLabel(1)).willReturn("someId");
given(generatedKeysResultSet.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet.getObject(1)).willReturn(123, 456);
given(generatedKeysResultSet.next()).willReturn(true, true, false);
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet);

int[] values = new int[]{100, 200};
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
ps.setObject(i, values[i]);
}

@Override
public int getBatchSize() {
return 2;
}
};

KeyHolder keyHolder = new GeneratedKeyHolder();
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);

assertThat(keyHolder.getKeyList()).containsExactly(
Collections.singletonMap("someId", 123),
Collections.singletonMap("someId", 456));
}

@Test
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException {
final int[] rowsAffected = new int[] {1, 2};
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
given(databaseMetaData.supportsBatchUpdates()).willReturn(false);
given(this.connection.getMetaData()).willReturn(databaseMetaData);
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
given(rsmd.getColumnCount()).willReturn(1);
given(rsmd.getColumnLabel(1)).willReturn("someId");
ResultSet generatedKeysResultSet1 = mock(ResultSet.class);
given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet1.getObject(1)).willReturn(123);
given(generatedKeysResultSet1.next()).willReturn(true, false);
ResultSet generatedKeysResultSet2 = mock(ResultSet.class);
given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd);
given(generatedKeysResultSet2.getObject(1)).willReturn(456);
given(generatedKeysResultSet2.next()).willReturn(true, false);
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2);

int[] values = new int[]{100, 200};
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
ps.setObject(i, values[i]);
}

@Override
public int getBatchSize() {
return 2;
}
};

KeyHolder keyHolder = new GeneratedKeyHolder();
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);

assertThat(keyHolder.getKeyList()).containsExactly(
Collections.singletonMap("someId", 123),
Collections.singletonMap("someId", 456));
}

private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
Expand Down