diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java
index 963f5ea624ff..fc94632d2af9 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java
@@ -17,10 +17,13 @@
 import io.airlift.configuration.ConfigDescription;
 import io.airlift.configuration.LegacyConfig;
 
+import javax.validation.constraints.Max;
 import javax.validation.constraints.Min;
 
 public class JdbcMetadataConfig
 {
+    static final int MAX_ALLOWED_INSERT_BATCH_SIZE = 1_000_000;
+
     private boolean allowDropTable;
     /*
      * Join pushdown is disabled by default as this is the safer option.
@@ -40,6 +43,8 @@ public class JdbcMetadataConfig
     // between performance and pushdown capabilities
     private int domainCompactionThreshold = 32;
 
+    private int insertBatchSize = 1000;
+
     public boolean isAllowDropTable()
     {
         return allowDropTable;
@@ -107,4 +112,19 @@ public JdbcMetadataConfig setDomainCompactionThreshold(int domainCompactionThres
         this.domainCompactionThreshold = domainCompactionThreshold;
         return this;
     }
+
+    @Min(1)
+    @Max(MAX_ALLOWED_INSERT_BATCH_SIZE)
+    public int getInsertBatchSize()
+    {
+        return insertBatchSize;
+    }
+
+    @Config("insert.batch-size")
+    @ConfigDescription("Maximum number of rows to insert in a single batch")
+    public JdbcMetadataConfig setInsertBatchSize(int insertBatchSize)
+    {
+        this.insertBatchSize = insertBatchSize;
+        return this;
+    }
 }
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java
index 983022eaac03..ff67b688b49a 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java
@@ -24,6 +24,7 @@
 import java.util.List;
 import java.util.Optional;
 
+import static io.trino.plugin.jdbc.JdbcMetadataConfig.MAX_ALLOWED_INSERT_BATCH_SIZE;
 import static io.trino.spi.StandardErrorCode.INVALID_SESSION_PROPERTY;
 import static io.trino.spi.session.PropertyMetadata.booleanProperty;
 import static io.trino.spi.session.PropertyMetadata.integerProperty;
@@ -36,6 +37,7 @@ public class JdbcMetadataSessionProperties
     public static final String AGGREGATION_PUSHDOWN_ENABLED = "aggregation_pushdown_enabled";
     public static final String TOPN_PUSHDOWN_ENABLED = "topn_pushdown_enabled";
     public static final String DOMAIN_COMPACTION_THRESHOLD = "domain_compaction_threshold";
+    public static final String INSERT_BATCH_SIZE = "insert_batch_size";
 
     private final List<PropertyMetadata<?>> properties;
 
@@ -65,6 +67,12 @@ public JdbcMetadataSessionProperties(JdbcMetadataConfig jdbcMetadataConfig, @Max
                         "Enable TopN pushdown",
                         jdbcMetadataConfig.isTopNPushdownEnabled(),
                         false))
+                .add(integerProperty(
+                        INSERT_BATCH_SIZE,
+                        "Insert batch size",
+                        jdbcMetadataConfig.getInsertBatchSize(),
+                        value -> validateInsertBatchSize(value, MAX_ALLOWED_INSERT_BATCH_SIZE),
+                        false))
                 .build();
     }
 
@@ -94,6 +102,11 @@ public static int getDomainCompactionThreshold(ConnectorSession session)
         return session.getProperty(DOMAIN_COMPACTION_THRESHOLD, Integer.class);
     }
 
+    public static int getInsertBatchSize(ConnectorSession session)
+    {
+        return session.getProperty(INSERT_BATCH_SIZE, Integer.class);
+    }
+
     private static void validateDomainCompactionThreshold(int domainCompactionThreshold, Optional<Integer> maxDomainCompactionThreshold)
     {
         if (domainCompactionThreshold < 1) {
@@ -106,4 +119,14 @@ private static void validateDomainCompactionThreshold(int domainCompactionThresh
             }
         });
     }
+
+    private static void validateInsertBatchSize(int maxBatchSize, int maxAllowedBatchSize)
+    {
+        if (maxBatchSize < 1) {
+            throw new TrinoException(INVALID_SESSION_PROPERTY, format("%s must be greater than 0: %s", INSERT_BATCH_SIZE, maxBatchSize));
+        }
+        if (maxBatchSize > maxAllowedBatchSize) {
+            throw new TrinoException(INVALID_SESSION_PROPERTY, format("%s cannot exceed %s: %s", INSERT_BATCH_SIZE, maxAllowedBatchSize, maxBatchSize));
+        }
+    }
 }
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java
index 04787aa1009e..ad3830a7a81e 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java
@@ -35,6 +35,7 @@
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
 import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR;
+import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getInsertBatchSize;
 import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
 import static java.util.concurrent.CompletableFuture.completedFuture;
 
@@ -46,6 +47,7 @@ public class JdbcPageSink
 
     private final List<Type> columnTypes;
     private final List<WriteFunction> columnWriters;
+    private final int maxBatchSize;
     private int batchSize;
 
     public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient)
@@ -92,6 +94,8 @@ public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, Jdbc
             closeWithSuppression(connection, e);
             throw new TrinoException(JDBC_ERROR, e);
         }
+
+        this.maxBatchSize = getInsertBatchSize(session);
     }
 
     @Override
@@ -106,7 +110,7 @@ public CompletableFuture<?> appendPage(Page page)
             statement.addBatch();
             batchSize++;
 
-            if (batchSize >= 1000) {
+            if (batchSize >= maxBatchSize) {
                 statement.executeBatch();
                 batchSize = 0;
             }
diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java
index b8bbb45b380d..c8b0adf36a4c 100644
--- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java
+++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java
@@ -39,11 +39,15 @@
 import org.intellij.lang.annotations.Language;
 import org.testng.SkipException;
 import org.testng.annotations.AfterClass;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import java.util.ArrayList;
 import java.util.List;
+import java.util.UUID;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.stream.Stream;
 
 import static com.google.common.base.Preconditions.checkState;
@@ -1278,4 +1282,50 @@ public void testDeleteWithVarcharPredicate()
     {
         throw new SkipException("This is implemented by testDeleteWithVarcharEqualityPredicate");
     }
+
+    @Test(dataProvider = "testInsertBatchSizeSessionProperty")
+    public void testInsertBatchSizeSessionProperty(Integer batchSize, Integer numberOfRows)
+    {
+        if (!hasBehavior(SUPPORTS_CREATE_TABLE)) {
+            throw new SkipException("CREATE TABLE is required for insert_batch_size test but is not supported");
+        }
+        Session session = Session.builder(getSession())
+                .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "insert_batch_size", batchSize.toString())
+                .build();
+
+        try (TestTable table = new TestTable(
+                getQueryRunner()::execute,
+                "test_insert_batch_size",
+                "(a varchar(36), b bigint)")) {
+            String values = String.join(",", makeValuesForInsertBatchSizeSessionPropertyTest(numberOfRows));
+            assertUpdate(session, "INSERT INTO " + table.getName() + " (a, b) VALUES " + values, numberOfRows);
+            assertQuery("SELECT COUNT(*) FROM " + table.getName(), format("VALUES %d", numberOfRows));
+        }
+    }
+
+    private static List<String> makeValuesForInsertBatchSizeSessionPropertyTest(int numberOfRows)
+    {
+        List<String> result = new ArrayList<>(numberOfRows);
+        for (int i = 0; i < numberOfRows; i++) {
+            result.add(format("('%s', %d)", UUID.randomUUID(), ThreadLocalRandom.current().nextLong()));
+        }
+        return result;
+    }
+
+    @DataProvider(name = "testInsertBatchSizeSessionProperty")
+    public static Object[][] batchSizeAndNumberOfRowsForInsertBatchSizePropertyTest()
+    {
+        return new Object[][] {
+                {100, 64},
+                {100, 100},
+                {100, 512},
+                {100, 1000},
+                {1000, 100},
+                {1000, 1000},
+                {1000, 5000},
+                {10000, 1000},
+                {10000, 5000},
+                {10000, 15000},
+        };
+    }
 }
diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java
index 693136be110d..28b569979fa8 100644
--- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java
+++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java
@@ -14,6 +14,7 @@
 package io.trino.plugin.jdbc;
 
 import com.google.common.collect.ImmutableMap;
+import io.airlift.configuration.ConfigurationFactory;
 import org.testng.annotations.Test;
 
 import java.util.Map;
@@ -21,6 +22,8 @@
 import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping;
 import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults;
 import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 public class TestJdbcMetadataConfig
 {
@@ -32,7 +35,8 @@ public void testDefaults()
                 .setJoinPushdownEnabled(false)
                 .setAggregationPushdownEnabled(true)
                 .setTopNPushdownEnabled(true)
-                .setDomainCompactionThreshold(32));
+                .setDomainCompactionThreshold(32)
+                .setInsertBatchSize(1000));
     }
 
     @Test
@@ -44,6 +48,7 @@ public void testExplicitPropertyMappings()
                 .put("aggregation-pushdown.enabled", "false")
                 .put("domain-compaction-threshold", "42")
                 .put("topn-pushdown.enabled", "false")
+                .put("insert.batch-size", "24")
                 .build();
 
         JdbcMetadataConfig expected = new JdbcMetadataConfig()
@@ -51,8 +56,30 @@ public void testExplicitPropertyMappings()
                 .setJoinPushdownEnabled(true)
                 .setAggregationPushdownEnabled(false)
                 .setTopNPushdownEnabled(false)
-                .setDomainCompactionThreshold(42);
+                .setDomainCompactionThreshold(42)
+                .setInsertBatchSize(24);
 
         assertFullMapping(properties, expected);
     }
+
+    @Test
+    public void testInsertBatchSizeValidation()
+    {
+        assertThatThrownBy(() -> makeConfig(ImmutableMap.of("insert.batch-size", "-42")))
+                .hasMessageContaining("insert.batch-size: must be greater than or equal to 1");
+
+        assertThatThrownBy(() -> makeConfig(ImmutableMap.of("insert.batch-size", "0")))
+                .hasMessageContaining("insert.batch-size: must be greater than or equal to 1");
+
+        assertThatCode(() -> makeConfig(ImmutableMap.of("insert.batch-size", "1")))
+                .doesNotThrowAnyException();
+
+        assertThatCode(() -> makeConfig(ImmutableMap.of("insert.batch-size", "42")))
+                .doesNotThrowAnyException();
+    }
+
+    private static JdbcMetadataConfig makeConfig(Map<String, String> props)
+    {
+        return new ConfigurationFactory(props).build(JdbcMetadataConfig.class);
+    }
 }
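---

Usage sketch, not part of the patch above: how the new knob would be exercised once a JDBC-based catalog picks up this change. The catalog name "example" is hypothetical; the property names come from the patch itself (config property insert.batch-size, session property insert_batch_size), with a default of 1000 rows per batch and a hard cap of MAX_ALLOWED_INSERT_BATCH_SIZE (1,000,000). The batch size can be set statically in the catalog properties file:

    # etc/catalog/example.properties -- static default for this catalog
    insert.batch-size=5000

or overridden per query through the catalog session property:

    -- per-session override, validated against the 1,000,000 cap
    SET SESSION example.insert_batch_size = 5000;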