diff --git a/nakadi-producer-spring-boot-starter/src/main/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcher.java b/nakadi-producer-spring-boot-starter/src/main/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcher.java
new file mode 100644
index 0000000..b1f5b09
--- /dev/null
+++ b/nakadi-producer-spring-boot-starter/src/main/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcher.java
@@ -0,0 +1,333 @@
+package org.zalando.nakadiproducer.eventlog.impl.batcher;
+
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
+import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.Spliterator;
+import java.util.Spliterators;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
+import static java.util.stream.Collectors.joining;
+import static java.util.stream.Collectors.toList;
+import static org.springframework.jdbc.core.namedparam.SqlParameterSource.TYPE_UNKNOWN;
+import static org.zalando.fahrschein.Preconditions.checkArgument;
+
+/**
+ * A helper class to simulate query batching for SQL statements which return data,
+ * i.e. SELECT or anything with a RETURNING clause.
+ * Inspired by "Batching Select Statements in JDBC" (by Jeanne Boyarski).
+ * <p>
+ * The idea here is to prepare prepared statements returning result sets of a common row type,
+ * for several input batch sizes (e.g. 51, 13, 4, 1), and then split our actual input into these
+ * sizes (mostly the largest one, the smaller ones are then used for the rest).
+ * We then use the query*() methods to give the input to the DB, and compose the output
+ * together into one stream/list.
+ * <p>
+ * Advantages:
+ * <ul>
+ * <li>It's fewer round trips than sending each query one-by-one.</li>
+ * <li>Compared to building one statement for each input list, the DB only has a small
+ * number of different prepared statements to look at and optimize.</li>
+ * </ul>
+ *
+ * @param <T> the type of the result rows, as produced by the result row mapper.
+ */
+public class QueryStatementBatcher<T> {
+
+    public static final String DEFAULT_TEMPLATE_PLACEHOLDER = "#";
+    public static final String DEFAULT_TEMPLATE_SEPARATOR = ", ";
+
+    private static final int[] DEFAULT_TEMPLATE_SIZES = {51, 13, 4, 1};
+
+    private final RowMapper<T> resultRowMapper;
+    // Not private, so that the tests in this package can inspect the expanded templates.
+    final List<SubTemplate> subTemplates;
+
+    /**
+     * Sets up a QueryStatementBatcher for a specific set of statements, composed from prefix, repeated part,
+     * (default separator) and suffix. The default batch sizes (51, 13, 4, 1) are used.
+     *
+     * @param templatePrefix   A prefix which will be prepended to the repeated part of the query. It can contain
+     *                         parameter placeholders as usual for NamedParameterJdbcTemplate.
+     * @param templateRepeated The part of the query string which will be repeated according to the number of
+     *                         parameter sets. The parameter placeholders in here should contain {@code "#"}.
+     *                         Occurrences of this in the generated queries will be separated by the
+     *                         default separator (a comma).
+     * @param templateSuffix   A suffix which will be used after the repeated part. It can contain
+     *                         parameter placeholders as usual for NamedParameterJdbcTemplate.
+     * @param resultMapper     A mapper which will be used to map the results of the queries (JDBC ResultSets)
+     *                         to whatever output format is desired.
+     */
+    public QueryStatementBatcher(String templatePrefix, String templateRepeated, String templateSuffix,
+                                 RowMapper<T> resultMapper) {
+        this(templatePrefix, templateRepeated, DEFAULT_TEMPLATE_PLACEHOLDER, DEFAULT_TEMPLATE_SEPARATOR,
+                templateSuffix, resultMapper, DEFAULT_TEMPLATE_SIZES);
+    }
+
+    QueryStatementBatcher(String templatePrefix, String templateRepeated, String templatePlaceholder,
+                          String templateSeparator, String templateSuffix, RowMapper<T> resultMapper) {
+        this(templatePrefix, templateRepeated, templatePlaceholder, templateSeparator, templateSuffix,
+                resultMapper, DEFAULT_TEMPLATE_SIZES);
+    }
+
+    QueryStatementBatcher(String templatePrefix, String templateRepeated, String templateSuffix,
+                          RowMapper<T> resultMapper, int... templateSizes) {
+        this(templatePrefix, templateRepeated, DEFAULT_TEMPLATE_PLACEHOLDER, DEFAULT_TEMPLATE_SEPARATOR,
+                templateSuffix, resultMapper, templateSizes);
+    }
+
+    /**
+     * Sets up a QueryStatementBatcher for a specific set of statements composed from prefix, repeated part,
+     * separator and suffix.
+     *
+     * @param templatePrefix      A prefix which will be prepended to the repeated part of the query. It can
+     *                            contain parameter placeholders as usual for NamedParameterJdbcTemplate.
+     * @param templateRepeated    The part of the query string which will be repeated according to the number
+     *                            of parameter sets. The parameter placeholders in here (if they vary between
+     *                            parameter sets) should contain the templatePlaceholder.
+     * @param templatePlaceholder This placeholder is to be used as part of the parameter names in the
+     *                            repeated templates.
+     * @param templateSeparator   This separator will be used between the repeated parts of the query.
+     * @param templateSuffix      A suffix which will be used after the repeated part. It can contain
+     *                            parameter placeholders as usual for NamedParameterJdbcTemplate.
+     * @param resultMapper        A mapper which will be used to map the results of the queries
+     *                            (JDBC ResultSets) to whatever output format is desired.
+     * @param templateSizes       A non-empty sequence of integers, in any order. Smallest one needs to be 1.
+     *                            This indicates the sizes (number of parameter sets used) to be used for the
+     *                            individual queries. The passed array is not modified.
+     */
+    QueryStatementBatcher(String templatePrefix, String templateRepeated, String templatePlaceholder,
+                          String templateSeparator, String templateSuffix, RowMapper<T> resultMapper,
+                          int... templateSizes) {
+        checkArgument(templateSizes != null && templateSizes.length > 0,
+                "at least one template size is required!");
+        this.resultRowMapper = resultMapper;
+
+        // Sort a copy: the caller's array (which may even be the shared
+        // DEFAULT_TEMPLATE_SIZES constant) must not be modified.
+        int[] descendingSizes = sortedDescending(templateSizes);
+        checkArgument(descendingSizes[descendingSizes.length - 1] == 1,
+                "smallest template size is not 1!");
+        this.subTemplates = IntStream.of(descendingSizes)
+                .mapToObj(size -> new SubTemplate(
+                        size,
+                        composeTemplate(size, templatePrefix, templateRepeated, templatePlaceholder,
+                                templateSeparator, templateSuffix),
+                        templatePlaceholder))
+                .collect(toList());
+    }
+
+    /**
+     * Builds the query text for one batch size: the prefix, then {@code valueCount} copies of the
+     * repeated part (with the placeholder replaced by 0, 1, ...) joined by the separator, then the suffix.
+     */
+    static String composeTemplate(int valueCount, String prefix, String repeated, String placeholder,
+                                  String separator, String suffix) {
+        return IntStream.range(0, valueCount)
+                .mapToObj(i -> repeated.replace(placeholder, String.valueOf(i)))
+                .collect(joining(separator, prefix, suffix));
+    }
+
+    /**
+     * Queries the database for a set of parameter sources, in an optimized way.
+     * This version should be used if there are no parameters in the non-repeated part
+     * of the query template.
+     *
+     * @param database       the DB connection in form of a spring NamedParameterJdbcTemplate.
+     * @param repeatedInputs A stream of repeated inputs. The names of the parameters here
+     *                       should contain the placeholder (by default "#").
+     * @return A stream of results, one for each parameter source in the repeated input.
+     */
+    public Stream<T> queryForStream(NamedParameterJdbcTemplate database,
+                                    Stream<MapSqlParameterSource> repeatedInputs) {
+        return queryForStream(database, new MapSqlParameterSource(), repeatedInputs);
+    }
+
+    /**
+     * Queries the database for a set of parameter sources, in an optimized way.
+     * This version should be used if there are parameters in the non-repeated part
+     * of the template.
+     *
+     * @param database        the DB connection in form of a spring NamedParameterJdbcTemplate.
+     * @param commonArguments a parameter source for any template parameters in the
+     *                        non-repeated part of the query (or parameters in the
+     *                        repeated part which don't change between input).
+     * @param repeatedInputs  A stream of repeated inputs. The names of the parameters here
+     *                        should contain the placeholder (by default "#").
+     * @return A stream of results, one for each parameter source in the repeated input.
+     */
+    public Stream<T> queryForStream(NamedParameterJdbcTemplate database,
+                                    MapSqlParameterSource commonArguments,
+                                    Stream<MapSqlParameterSource> repeatedInputs) {
+        return queryForStreamRecursive(database, commonArguments, repeatedInputs, 0);
+    }
+
+    /**
+     * Chunks the input by the batch size of the sub-template at {@code subTemplateIndex}. Full chunks
+     * are queried with that sub-template, a smaller chunk is passed on to the next smaller sub-template.
+     */
+    private Stream<T> queryForStreamRecursive(NamedParameterJdbcTemplate database,
+                                              MapSqlParameterSource commonArguments,
+                                              Stream<MapSqlParameterSource> repeatedInputs,
+                                              int subTemplateIndex) {
+        SubTemplate subTemplate = subTemplates.get(subTemplateIndex);
+
+        Stream<List<MapSqlParameterSource>> chunkedStream =
+                chunkStream(repeatedInputs, subTemplate.inputCount);
+        return chunkedStream.flatMap(chunk -> {
+            if (chunk.size() == subTemplate.inputCount) {
+                return subTemplate.queryForStream(database, commonArguments, chunk, resultRowMapper);
+            } else {
+                // The smallest batch size is 1, so a non-empty chunk always fits one of the
+                // remaining sub-templates and the index can't run past the end of the list.
+                return queryForStreamRecursive(database, commonArguments, chunk.stream(), subTemplateIndex + 1);
+            }
+        });
+    }
+
+    /**
+     * This nested class handles a single "batch size".
+     */
+    static class SubTemplate {
+        final int inputCount;
+        final String expandedTemplate;
+        final String namePlaceholder;
+
+        private SubTemplate(int inputCount, String expandedTemplate, String namePlaceholder) {
+            this.inputCount = inputCount;
+            this.expandedTemplate = expandedTemplate;
+            this.namePlaceholder = namePlaceholder;
+        }
+
+        /**
+         * Runs the expanded template once, for exactly {@code inputCount} inputs. The placeholder in the
+         * parameter names of the input at index n is replaced by n, the common arguments are copied as-is.
+         */
+        <R> Stream<R> queryForStream(NamedParameterJdbcTemplate database,
+                                     MapSqlParameterSource commonArguments,
+                                     List<MapSqlParameterSource> repeatedInputs,
+                                     RowMapper<R> mapper) {
+            checkArgument(repeatedInputs.size() == inputCount,
+                    "input size = %s != %s = inputCount", repeatedInputs.size(), inputCount);
+
+            MapSqlParameterSource params = new MapSqlParameterSource();
+            Stream.of(commonArguments.getParameterNames())
+                    .forEach(name -> copyTypeAndValue(commonArguments, name, params, name));
+            IntStream.range(0, inputCount)
+                    .forEach(index -> {
+                        MapSqlParameterSource input = repeatedInputs.get(index);
+                        String textIndex = String.valueOf(index);
+                        Stream.of(input.getParameterNames())
+                                .forEach(name -> copyTypeAndValue(input, name,
+                                        params, name.replace(namePlaceholder, textIndex)));
+                    });
+
+            return database.queryForStream(expandedTemplate, params, mapper);
+        }
+
+        private static void copyTypeAndValue(MapSqlParameterSource source, String sourceName,
+                                             MapSqlParameterSource target, String targetName) {
+            target.addValue(targetName, source.getValue(sourceName));
+            int type = source.getSqlType(sourceName);
+            if (type != TYPE_UNKNOWN) {
+                target.registerSqlType(targetName, type);
+            }
+            String typeName = source.getTypeName(sourceName);
+            if (typeName != null) {
+                target.registerTypeName(targetName, typeName);
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "SubTemplate{" +
+                    "inputCount=" + inputCount +
+                    ", expandedTemplate='" + expandedTemplate + '\'' +
+                    ", namePlaceholder='" + namePlaceholder + '\'' +
+                    '}';
+        }
+    }
+
+    /**
+     * Splits a stream into a stream of chunks of equal size, with possibly one final chunk of smaller size.
+     * This is a terminal operation on {@code input} (its spliterator is requested), but its elements are
+     * only accessed when the returned stream is processed. Closing the returned stream closes {@code input}.
+     *
+     * @param input     a stream of elements to be chunked.
+     * @param chunkSize the size of each chunk, needs to be positive.
+     * @param <E>       the type of elements in input.
+     * @return a new stream of lists. The returned lists can be modified, but that
+     * doesn't have any impact on the source of input.
+     * The stream is non-null, and preserves the ordered/immutable/concurrent/distinct
+     * properties of the input stream.
+     */
+    static <E> Stream<List<E>> chunkStream(Stream<E> input, int chunkSize) {
+        // A non-positive chunk size would produce an endless stream of empty chunks.
+        checkArgument(chunkSize > 0, "chunkSize must be positive, but was " + chunkSize);
+
+        // inspired by https://stackoverflow.com/a/59164175/600500
+        // I think there might be a way of optimizing this by actually using the
+        // spliterator for chunking, but that seems to become more complicated.
+        Spliterator<E> inputSpliterator = input.spliterator();
+        // Not transferring the characteristics Spliterator.SORTED, Spliterator.SIZED, Spliterator.SUBSIZED.
+        int characteristics =
+                // these characteristics should reflect onto the chunked spliterator
+                (inputSpliterator.characteristics()
+                        & (Spliterator.ORDERED | Spliterator.IMMUTABLE | Spliterator.CONCURRENT | Spliterator.DISTINCT))
+                // the lists returned are always non-null (even if they might contain null elements)
+                | Spliterator.NONNULL;
+
+        Iterator<List<E>> chunkIterator = new Iterator<>() {
+            private final Iterator<E> sourceIterator = Spliterators.iterator(inputSpliterator);
+
+            @Override
+            public boolean hasNext() {
+                return sourceIterator.hasNext();
+            }
+
+            @Override
+            public List<E> next() {
+                if (!sourceIterator.hasNext()) {
+                    throw new NoSuchElementException("no more elements!");
+                }
+                List<E> result = new ArrayList<>(chunkSize);
+                for (int i = 0; i < chunkSize && sourceIterator.hasNext(); i++) {
+                    result.add(sourceIterator.next());
+                }
+                return result;
+            }
+        };
+        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(chunkIterator, characteristics), false)
+                .onClose(input::close);
+    }
+
+    /**
+     * Returns a copy of the given sizes, sorted in descending order. The argument itself stays untouched.
+     */
+    private static int[] sortedDescending(int[] templateSizes) {
+        // There is no Arrays.sort for int[] with comparator (or with flag to tell "descending"),
+        // so we sort the copy normally and then reverse it.
+        int[] sorted = templateSizes.clone();
+        Arrays.sort(sorted);
+        reverse(sorted);
+        return sorted;
+    }
+
+    private static void reverse(int[] array) {
+        // https://stackoverflow.com/a/3523066/600500
+        for (int left = 0, right = array.length - 1; left < right; left++, right--) {
+            int temp = array[left];
+            array[left] = array[right];
+            array[right] = temp;
+        }
+    }
+
+}
diff --git a/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherIT.java b/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherIT.java
new file mode 100644
index 0000000..406703b
--- /dev/null
+++ b/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherIT.java
@@ -0,0 +1,140 @@
+package org.zalando.nakadiproducer.eventlog.impl.batcher;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
+import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
+import org.zalando.nakadiproducer.BaseMockedExternalCommunicationIT;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+import static java.util.stream.Collectors.toList;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+
+public class QueryStatementBatcherIT extends BaseMockedExternalCommunicationIT {
+
+    private static final RowMapper<Integer> ID_ROW_MAPPER = (row, n) -> row.getInt("id");
+
+    @Autowired
+    private NamedParameterJdbcTemplate jdbcTemplate;
+
+    @BeforeEach
+    public void setUpTable() {
+        jdbcTemplate.update("CREATE TABLE x (id SERIAL, a INT, b INT)", Map.of());
+    }
+
+    @AfterEach
+    public void dropTable() {
+        jdbcTemplate.update("DROP TABLE x;", Map.of());
+    }
+
+    @Test
+    public void testStreamEvents() {
+        QueryStatementBatcher<Integer> batcher = createInsertPairsReturningIdBatcher();
+        MapSqlParameterSource commonArguments = new MapSqlParameterSource();
+        int expectedCount = 31;
+        List<MapSqlParameterSource> repeatedInputs = IntStream.range(0, expectedCount)
+                .mapToObj(i -> new MapSqlParameterSource()
+                        .addValue("a#", i)
+                        .addValue("b#", 5 * i))
+                .collect(toList());
+
+        List<Integer> resultList = batcher.queryForStream(
+                        jdbcTemplate, repeatedInputs.stream())
+                .collect(toList());
+        assertThat(resultList, hasSize(expectedCount));
+
+        List<Integer> secondResultList = batcher.queryForStream(
+                        jdbcTemplate, commonArguments, repeatedInputs.stream())
+                .collect(toList());
+
+        assertThat(secondResultList, hasSize(expectedCount));
+        assertThat(secondResultList.get(0), is(expectedCount + 1));
+    }
+
+    private static QueryStatementBatcher<Integer> createInsertPairsReturningIdBatcher() {
+        return new QueryStatementBatcher<>(
+                "INSERT INTO x (a, b) VALUES ", "(:a#, :b#)", " RETURNING id",
+                ID_ROW_MAPPER,
+                51, 13, 4, 1);
+    }
+
+    @Test
+    @Disabled("Running benchmarks takes too long.")
+    public void benchmarkWithBatcher() {
+        int totalCount = 5000;
+        List<MapSqlParameterSource> inputs = prepareInputs(totalCount);
+        Instant before = Instant.now();
+        QueryStatementBatcher<Integer> batcher = createInsertPairsReturningIdBatcher();
+        List<Integer> results = batcher.queryForStream(jdbcTemplate, inputs.stream()).collect(toList());
+        Instant after = Instant.now();
+        System.err.format("Inserting %s items took %s.\n", totalCount, Duration.between(before, after));
+        System.out.println(results);
+    }
+
+    private static List<MapSqlParameterSource> prepareInputs(int totalCount) {
+        return IntStream.range(0, totalCount)
+                .mapToObj(i -> new MapSqlParameterSource()
+                        .addValue("a#", 3 * i)
+                        .addValue("b#", 5 * i))
+                .collect(toList());
+    }
+
+    @Test
+    @Disabled("Running benchmarks takes too long.")
+    public void benchmarkWithoutBatcherSerial() {
+        int totalCount = 5000;
+        List<MapSqlParameterSource> inputs = prepareInputs(totalCount);
+        Instant before = Instant.now();
+        List<Integer> results = inputs.stream()
+                .map(source -> jdbcTemplate.queryForObject(
+                        "INSERT INTO x (a, b) VALUES (:a#, :b#) RETURNING id",
+                        source, ID_ROW_MAPPER))
+                .collect(toList());
+        Instant after = Instant.now();
+        System.err.format("Inserting %s items took %s.\n", totalCount, Duration.between(before, after));
+        System.out.println(results);
+    }
+
+    @Test
+    @Disabled("Running benchmarks takes too long.")
+    public void benchmarkWithoutBatcherParallel() {
+        int totalCount = 5000;
+        List<MapSqlParameterSource> inputs = prepareInputs(totalCount);
+        Instant before = Instant.now();
+        List<Integer> results = inputs.parallelStream()
+                .map(source -> jdbcTemplate.queryForObject(
+                        "INSERT INTO x (a, b) VALUES (:a#, :b#) RETURNING id",
+                        source, ID_ROW_MAPPER))
+                .collect(toList());
+        Instant after = Instant.now();
+        System.err.format("Inserting %s items took %s.\n", totalCount, Duration.between(before, after));
+        System.out.println(results);
+    }
+
+    @Test
+    @Disabled("Running benchmarks takes too long.")
+    public void benchmarkBatchWithoutReturn() {
+        int totalCount = 5000;
+        List<MapSqlParameterSource> inputs = prepareInputs(totalCount);
+        MapSqlParameterSource[] inputArray = inputs.toArray(new MapSqlParameterSource[0]);
+        Instant before = Instant.now();
+        int[] results = jdbcTemplate.batchUpdate(
+                "INSERT INTO x (a, b) VALUES (:a#, :b#)",
+                inputArray);
+        Instant after = Instant.now();
+        System.err.format("Inserting %s items took %s.\n", totalCount, Duration.between(before, after));
+        System.out.println(Arrays.toString(results));
+    }
+}
diff --git a/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherTest.java b/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherTest.java
new file mode 100644
index 0000000..9c63dcb
--- /dev/null
+++ b/nakadi-producer-spring-boot-starter/src/test/java/org/zalando/nakadiproducer/eventlog/impl/batcher/QueryStatementBatcherTest.java
@@ -0,0 +1,75 @@
+package org.zalando.nakadiproducer.eventlog.impl.batcher;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toList;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+
+public class QueryStatementBatcherTest {
+
+    @Test
+    public void testComposeTemplateInsert() {
+        String prefix = "INSERT INTO x (a, b) VALUES ";
+        String repeated = "(:a#, :b#)";
+        String placeholder = "#";
+        String separator = ", ";
+        String suffix = " RETURNING id";
+
+        assertThat(QueryStatementBatcher.composeTemplate(1, prefix, repeated, placeholder, separator, suffix),
+                is("INSERT INTO x (a, b) VALUES (:a0, :b0) RETURNING id"));
+        assertThat(QueryStatementBatcher.composeTemplate(2, prefix, repeated, placeholder, separator, suffix),
+                is("INSERT INTO x (a, b) VALUES (:a0, :b0), (:a1, :b1) RETURNING id"));
+    }
+
+    @Test
+    public void testComposeTemplateSelectWhere() {
+        String prefix = "SELECT a, b FROM x WHERE id IN (";
+        String repeated = ":id#";
+        String separator = ", ";
+        String placeholder = "#";
+        String suffix = ")";
+
+        assertThat(QueryStatementBatcher.composeTemplate(1, prefix, repeated, placeholder, separator, suffix),
+                is("SELECT a, b FROM x WHERE id IN (:id0)"));
+        assertThat(QueryStatementBatcher.composeTemplate(2, prefix, repeated, placeholder, separator, suffix),
+                is("SELECT a, b FROM x WHERE id IN (:id0, :id1)"));
+    }
+
+    @Test
+    public void testCreateSubTemplates() {
+        QueryStatementBatcher<Object> batcher = new QueryStatementBatcher<>(
+                "SELECT a, b FROM x WHERE id IN (", ":id#", ")",
+                (row, n) -> null,
+                21, 6, 1);
+        assertThat(batcher.subTemplates, hasSize(3));
+        assertThat(batcher.subTemplates.get(0).expandedTemplate,
+                is("SELECT a, b FROM x WHERE id IN (:id0, :id1, :id2, :id3, :id4, :id5, :id6, :id7,"
+                        + " :id8, :id9, :id10, :id11, :id12, :id13, :id14, :id15, :id16, :id17, :id18, :id19, :id20)"));
+    }
+
+    @Test
+    public void testTemplateSizesAreSortedWithoutModifyingTheInput() {
+        int[] sizes = {1, 21, 6};
+        QueryStatementBatcher<Object> batcher = new QueryStatementBatcher<>(
+                "SELECT a, b FROM x WHERE id IN (", ":id#", ")",
+                (row, n) -> null,
+                sizes);
+        assertThat(sizes, is(new int[]{1, 21, 6}));
+        assertThat(batcher.subTemplates, hasSize(3));
+        assertThat(batcher.subTemplates.get(0).inputCount, is(21));
+        assertThat(batcher.subTemplates.get(1).inputCount, is(6));
+        assertThat(batcher.subTemplates.get(2).inputCount, is(1));
+    }
+
+    @Test
+    public void testChunkStream() {
+        List<List<Integer>> chunks = QueryStatementBatcher.chunkStream(Stream.of(1, 2, 3, 4, 5), 2)
+                .collect(toList());
+        assertThat(chunks, is(List.of(List.of(1, 2), List.of(3, 4), List.of(5))));
+    }
+}