From 13aa093761a45accf617d15cf8cce6afaa18e8e5 Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Tue, 18 Jul 2023 14:26:06 +0200 Subject: [PATCH] Use `putIdentity` method to reduce code duplication --- .../main/java/io/trino/sql/planner/QueryPlanner.java | 12 ++++++------ .../java/io/trino/sql/planner/RelationPlanner.java | 4 ++-- .../rule/PushAggregationThroughOuterJoin.java | 2 +- .../rule/PushProjectionThroughExchange.java | 6 +++--- .../optimizations/HashGenerationOptimizer.java | 2 +- .../rule/TestPushProjectionThroughExchange.java | 4 ++-- .../iterative/rule/TestPushTopNThroughProject.java | 2 +- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 29589b18b00f..4d98a1f81bb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -547,7 +547,7 @@ public PlanNode plan(Delete node) Symbol symbol = relationPlan.getFieldMappings().get(fieldIndex); columnSymbolsBuilder.add(symbol); if (mergeAnalysis.getRedistributionColumnHandles().contains(columnHandle)) { - assignmentsBuilder.put(symbol, symbol.toSymbolReference()); + assignmentsBuilder.putIdentity(symbol); } else { assignmentsBuilder.put(symbol, new NullLiteral()); @@ -720,11 +720,11 @@ public PlanNode plan(Update node) for (ColumnHandle column : mergeAnalysis.getRedistributionColumnHandles()) { int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(column), "Could not find fieldIndex for redistribution column"); Symbol symbol = relationPlan.getFieldMappings().get(fieldIndex); - projectionAssignmentsBuilder.put(symbol, symbol.toSymbolReference()); + projectionAssignmentsBuilder.putIdentity(symbol); } // Add the rest of the page columns: rowId, merge row, case number and is_distinct - projectionAssignmentsBuilder.put(rowIdSymbol, rowIdSymbol.toSymbolReference()); + projectionAssignmentsBuilder.putIdentity(rowIdSymbol); projectionAssignmentsBuilder.put(mergeRowSymbol, mergeRow); projectionAssignmentsBuilder.put(caseNumberSymbol, new GenericLiteral("INTEGER", "0")); projectionAssignmentsBuilder.put(isDistinctSymbol, TRUE_LITERAL); @@ -858,10 +858,10 @@ public MergeWriterNode plan(Merge merge) for (ColumnHandle column : mergeAnalysis.getRedistributionColumnHandles()) { int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(column), "Could not find fieldIndex for redistribution column"); Symbol symbol = planWithPresentColumn.getFieldMappings().get(fieldIndex); - projectionAssignmentsBuilder.put(symbol, symbol.toSymbolReference()); + projectionAssignmentsBuilder.putIdentity(symbol); } - projectionAssignmentsBuilder.put(uniqueIdSymbol, uniqueIdSymbol.toSymbolReference()); - projectionAssignmentsBuilder.put(rowIdSymbol, rowIdSymbol.toSymbolReference()); + projectionAssignmentsBuilder.putIdentity(uniqueIdSymbol); + projectionAssignmentsBuilder.putIdentity(rowIdSymbol); projectionAssignmentsBuilder.put(mergeRowSymbol, caseExpression); ProjectNode subPlanProject = new ProjectNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index ae1527599426..3882d037d586 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -975,13 +975,13 @@ If casts are redundant (due to column type and common type being equal), for (int field : joinAnalysis.getOtherLeftFields()) { Symbol symbol = left.getFieldMappings().get(field); outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + assignments.putIdentity(symbol); } for (int field : joinAnalysis.getOtherRightFields()) { Symbol symbol = right.getFieldMappings().get(field); outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + assignments.putIdentity(symbol); } return new RelationPlan( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index af67662b549f..46cee94055b7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -264,7 +264,7 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); } else { - assignmentsBuilder.put(symbol, symbol.toSymbolReference()); + assignmentsBuilder.putIdentity(symbol); } } return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 57d4ff82ea77..0296605344c7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -97,7 +97,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) partitioningColumns.stream() .map(outputToInputMap::get) .forEach(inputSymbol -> { - projections.put(inputSymbol, inputSymbol.toSymbolReference()); + projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); @@ -105,7 +105,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) exchange.getPartitioningScheme().getHashColumn() .map(outputToInputMap::get) .ifPresent(inputSymbol -> { - projections.put(inputSymbol, inputSymbol.toSymbolReference()); + projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); @@ -116,7 +116,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) .filter(symbol -> !partitioningColumns.contains(symbol)) .map(outputToInputMap::get) .forEach(inputSymbol -> { - projections.put(inputSymbol, inputSymbol.toSymbolReference()); + projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index 365d3ea97243..e60459a9898d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -759,7 +759,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo for (Symbol symbol : planWithProperties.getNode().getOutputSymbols()) { HashComputation partitionSymbols = resultHashSymbols.get(symbol); if (partitionSymbols == null || requiredHashes.getHashes().contains(partitionSymbols)) { - assignments.put(symbol, symbol.toSymbolReference()); + assignments.putIdentity(symbol); if (partitionSymbols != null) { outputHashSymbols.put(partitionSymbols, symbol); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 8f57609cd50f..f521a55e7dea 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -61,8 +61,8 @@ public void testDoesNotFireNarrowingProjection() return p.project( Assignments.builder() - .put(a, a.toSymbolReference()) - .put(b, b.toSymbolReference()) + .putIdentity(a) + .putIdentity(b) .build(), p.exchange(e -> e .addSource(p.values(a, b, c)) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java index 05214b9afb7f..55bc03ca5207 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -173,7 +173,7 @@ public void testPushTopNThroughOverlappingDereferences() Assignments.builder() .put(p.symbol("b"), new SubscriptExpression(a.toSymbolReference(), new LongLiteral("1"))) .put(p.symbol("c", rowType), a.toSymbolReference()) - .put(d, d.toSymbolReference()) + .putIdentity(d) .build(), p.values(a, d))); })