From 8820f5a9d2dd9067661617f859cf611ced826ae9 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Fri, 4 Oct 2024 11:33:50 -0700 Subject: [PATCH] Simplify planning of MERGE Currently, the planner inserts a projection of the following shape to assemble the merged row: merge_row := ( CASE WHEN ... THEN ROW(..., $not((present IS NULL)), , 0) WHEN ... THEN ROW(..., $not((present IS NULL)), , 1) ... ELSE ROW(, $not((present IS NULL)), -1, -1) END) This change replaces the ELSE branch to return a single null instead of a synthetic value with nulls. By reducing the size of the projection, it allows for wider tables to be used with MERGE. --- ...hangeOnlyUpdatedColumnsMergeProcessor.java | 10 +---- .../DeleteAndInsertMergeProcessor.java | 26 +++++++------ .../operator/MergeRowChangeProcessor.java | 2 - .../io/trino/sql/planner/QueryPlanner.java | 22 +++++------ .../TestDeleteAndInsertMergeProcessor.java | 37 +++++++++++-------- 5 files changed, 48 insertions(+), 49 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java index 4d93857d78a5..4ce8d5ca5ec6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -66,12 +66,6 @@ public Page transformPage(Page inputPage) checkArgument(positionCount > 0, "positionCount should be > 0, but is %s", positionCount); Block mergeRow = inputPage.getBlock(mergeRowChannel).getLoadedBlock(); - if (mergeRow.mayHaveNull()) { - for (int position = 0; position < positionCount; position++) { - checkArgument(!mergeRow.isNull(position), "The mergeRow may not have null rows"); - } - } - List fields = getRowFieldsFromBlock(mergeRow); List builder = new ArrayList<>(dataColumnChannels.size() + 3); for (int channel : dataColumnChannels) { @@ -86,7 +80,7 @@ public Page transformPage(Page inputPage) int defaultCaseCount = 0; for (int position = 0; position < positionCount; position++) { - if (TINYINT.getByte(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) { + if (mergeRow.isNull(position)) { defaultCaseCount++; } } @@ -97,7 +91,7 @@ public Page transformPage(Page inputPage) int usedCases = 0; int[] positions = new int[positionCount - defaultCaseCount]; for (int position = 0; position < positionCount; position++) { - if (TINYINT.getByte(operationChannelBlock, position) != DEFAULT_CASE_OPERATION_NUMBER) { + if (!mergeRow.isNull(position)) { positions[usedCases] = position; usedCases++; } diff --git a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java index bfcc2fc16808..1a76b33a78a7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeleteAndInsertMergeProcessor.java @@ -99,22 +99,24 @@ public Page transformPage(Page inputPage) int originalPositionCount = inputPage.getPositionCount(); checkArgument(originalPositionCount > 0, "originalPositionCount should be > 0, but is %s", originalPositionCount); - List fields = getRowFieldsFromBlock(inputPage.getBlock(mergeRowChannel)); + Block mergeRow = inputPage.getBlock(mergeRowChannel); + List fields = getRowFieldsFromBlock(mergeRow); Block operationChannelBlock = fields.get(fields.size() - 2); int updatePositions = 0; int insertPositions = 0; int deletePositions = 0; for (int position = 0; position < originalPositionCount; position++) { - byte operation = TINYINT.getByte(operationChannelBlock, position); - switch (operation) { - case DEFAULT_CASE_OPERATION_NUMBER -> { /* ignored */ } - case INSERT_OPERATION_NUMBER -> insertPositions++; - case DELETE_OPERATION_NUMBER -> deletePositions++; - case UPDATE_OPERATION_NUMBER -> updatePositions++; - // This class will create such rows, they are not expected on input - case UPDATE_INSERT_OPERATION_NUMBER, UPDATE_DELETE_OPERATION_NUMBER -> throw new IllegalArgumentException("Unexpected operator number: " + operation); - default -> throw new IllegalArgumentException("Unknown operator number: " + operation); + if (!mergeRow.isNull(position)) { + byte operation = TINYINT.getByte(operationChannelBlock, position); + switch (operation) { + case INSERT_OPERATION_NUMBER -> insertPositions++; + case DELETE_OPERATION_NUMBER -> deletePositions++; + case UPDATE_OPERATION_NUMBER -> updatePositions++; + // This class will create such rows, they are not expected on input + case UPDATE_INSERT_OPERATION_NUMBER, UPDATE_DELETE_OPERATION_NUMBER -> throw new IllegalArgumentException("Unexpected operator number: " + operation); + default -> throw new IllegalArgumentException("Unknown operator number: " + operation); + } } } @@ -128,8 +130,8 @@ public Page transformPage(Page inputPage) PageBuilder pageBuilder = new PageBuilder(totalPositions, pageTypes); for (int position = 0; position < originalPositionCount; position++) { - byte operation = TINYINT.getByte(operationChannelBlock, position); - if (operation != DEFAULT_CASE_OPERATION_NUMBER) { + if (!mergeRow.isNull(position)) { + byte operation = TINYINT.getByte(operationChannelBlock, position); // Delete and Update because both create a delete row if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { addDeleteRow(pageBuilder, inputPage, position, operation != DELETE_OPERATION_NUMBER); diff --git a/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java index 82720c08b8dd..67f0df051bb7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/MergeRowChangeProcessor.java @@ -18,8 +18,6 @@ public interface MergeRowChangeProcessor { - int DEFAULT_CASE_OPERATION_NUMBER = -1; - /** * Transform a page generated by an SQL MERGE operation into page of data columns and * operations. The SQL MERGE input page consists of the following: diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index ce4d98340717..8278515b22a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -31,6 +31,7 @@ import io.trino.spi.connector.SortOrder; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.NodeUtils; import io.trino.sql.PlannerContext; @@ -861,17 +862,16 @@ public MergeWriterNode plan(Merge merge) } } - // Build the "else" clause for the SearchedCaseExpression - ImmutableList.Builder rowBuilder = ImmutableList.builder(); - dataColumnSchemas.forEach(columnSchema -> - rowBuilder.add(new Constant(columnSchema.getType(), null))); - rowBuilder.add(not(metadata, new IsNull(presentColumn.toSymbolReference()))); - // The operation number - rowBuilder.add(new Constant(TINYINT, -1L)); - // The case number - rowBuilder.add(new Constant(INTEGER, -1L)); - - Case caseExpression = new Case(whenClauses.build(), new Row(rowBuilder.build())); + Case caseExpression = new Case( + whenClauses.build(), + new Constant( + RowType.anonymous(ImmutableList.builder() + .addAll(dataColumnSchemas.stream().map(ColumnSchema::getType).collect(toImmutableList())) + .add(BOOLEAN) + .add(TINYINT) + .add(INTEGER) + .build()), + null)); Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType()); Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java index b1e84ceea2a1..ce28fa59dcc2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -33,7 +33,6 @@ import java.util.List; import java.util.Optional; -import static io.trino.operator.MergeRowChangeProcessor.DEFAULT_CASE_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER; import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; @@ -58,20 +57,21 @@ public void testSimpleDeletedRowMerge() // THEN DELETE // expected: ('Dave', 11, 'Darbyshire') DeleteAndInsertMergeProcessor processor = makeMergeProcessor(); - Page inputPage = makePageFromBlocks( - 2, - Optional.empty(), - new Block[] { - makeLongArrayBlock(1, 1), // TransactionId - makeLongArrayBlock(1, 0), // rowId - makeIntArrayBlock(536870912, 536870912)}, // bucket - new Block[] { - makeVarcharArrayBlock("", "Dave"), // customer - makeIntArrayBlock(0, 11), // purchases - makeVarcharArrayBlock("", "Devon"), // address - makeByteArrayBlock(1, 1), // "present" boolean - makeByteArrayBlock(DEFAULT_CASE_OPERATION_NUMBER, DELETE_OPERATION_NUMBER), - makeIntArrayBlock(-1, 0)}); + Block[] rowIdBlocks = new Block[] { + makeLongArrayBlock(1, 1), // TransactionId + makeLongArrayBlock(1, 0), // rowId + makeIntArrayBlock(536870912, 536870912)}; // bucket + Block[] mergeCaseBlocks = new Block[] { + makeVarcharArrayBlock(null, "Dave"), // customer + new IntArrayBlock(2, Optional.of(new boolean[] {true, false}), new int[] {0, 11}), // purchases + makeVarcharArrayBlock(null, "Devon"), // address + new ByteArrayBlock(2, Optional.of(new boolean[] {true, false}), new byte[] {0, 1}), // "present" boolean + new ByteArrayBlock(2, Optional.of(new boolean[] {true, false}), new byte[] {0, DELETE_OPERATION_NUMBER}), // "present" boolean + new IntArrayBlock(2, Optional.of(new boolean[] {true, false}), new int[] {0, 0}) + }; + Page inputPage = new Page( + RowBlock.fromNotNullSuppressedFieldBlocks(2, Optional.empty(), rowIdBlocks), + RowBlock.fromNotNullSuppressedFieldBlocks(2, Optional.of(new boolean[] {true, false}), mergeCaseBlocks)); Page outputPage = processor.transformPage(inputPage); assertThat(outputPage.getPositionCount()).isEqualTo(1); @@ -215,7 +215,12 @@ private Block makeVarcharArrayBlock(String... elements) { BlockBuilder builder = VARCHAR.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), elements.length); for (String element : elements) { - VARCHAR.writeSlice(builder, Slices.utf8Slice(element)); + if (element == null) { + builder.appendNull(); + } + else { + VARCHAR.writeSlice(builder, Slices.utf8Slice(element)); + } } return builder.build(); }