Skip to content

Commit

Permalink
Pass partition channel types directly to LocalExchange
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav8297 authored and sopel39 committed Oct 27, 2022
1 parent 335d70b commit c7e0ecc
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public LocalExchange(
int defaultConcurrency,
PartitioningHandle partitioning,
List<Integer> partitionChannels,
List<Type> types,
List<Type> partitionChannelTypes,
Optional<Integer> partitionHashChannel,
DataSize maxBufferedBytes,
BlockTypeOperators blockTypeOperators,
Expand All @@ -106,10 +106,6 @@ public LocalExchange(
.map(buffer -> (Consumer<PageReference>) buffer::addPage)
.collect(toImmutableList());

List<Type> partitionChannelTypes = partitionChannels.stream()
.map(types::get)
.collect(toImmutableList());

this.memoryManager = new LocalExchangeMemoryManager(maxBufferedBytes.toBytes());
if (partitioning.equals(SINGLE_DISTRIBUTION)) {
exchangerSupplier = () -> new BroadcastExchanger(buffers, memoryManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3529,7 +3529,7 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan
operatorsCount,
node.getPartitioningScheme().getPartitioning().getHandle(),
ImmutableList.of(),
types,
ImmutableList.of(),
Optional.empty(),
maxLocalExchangeBufferSize,
blockTypeOperators,
Expand Down Expand Up @@ -3583,11 +3583,14 @@ else if (context.getDriverInstanceCount().isPresent()) {
}

List<Type> types = getSourceOperatorTypes(node, context.getTypes());
List<Integer> channels = node.getPartitioningScheme().getPartitioning().getArguments().stream()
List<Integer> partitionChannels = node.getPartitioningScheme().getPartitioning().getArguments().stream()
.map(argument -> node.getOutputSymbols().indexOf(argument.getColumn()))
.collect(toImmutableList());
Optional<Integer> hashChannel = node.getPartitioningScheme().getHashColumn()
.map(symbol -> node.getOutputSymbols().indexOf(symbol));
List<Type> partitionChannelTypes = partitionChannels.stream()
.map(types::get)
.collect(toImmutableList());

List<DriverFactoryParameters> driverFactoryParametersList = new ArrayList<>();
for (int i = 0; i < node.getSources().size(); i++) {
Expand All @@ -3603,8 +3606,8 @@ else if (context.getDriverInstanceCount().isPresent()) {
session,
driverInstanceCount,
node.getPartitioningScheme().getPartitioning().getHandle(),
channels,
types,
partitionChannels,
partitionChannelTypes,
hashChannel,
maxLocalExchangeBufferSize,
blockTypeOperators,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void testGatherSingleWriter()
8,
SINGLE_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(retainedSizeOfPages(99)),
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -185,7 +185,7 @@ public void testBroadcast()
2,
FIXED_BROADCAST_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -274,7 +274,7 @@ public void testRandom()
2,
FIXED_ARBITRARY_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -325,7 +325,7 @@ public void testScaleWriter()
3,
SCALED_WRITER_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(retainedSizeOfPages(4)),
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -406,7 +406,7 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
3,
SCALED_WRITER_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(retainedSizeOfPages(4)),
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -449,7 +449,7 @@ public void testNoWriterScalingWhenOnlyWriterMinSizeLimitIsExceeded()
3,
SCALED_WRITER_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(retainedSizeOfPages(20)),
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -493,7 +493,7 @@ public void testPassthrough()
2,
FIXED_PASSTHROUGH_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(retainedSizeOfPages(1)),
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -658,7 +658,7 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
2,
partitioningHandle,
ImmutableList.of(1),
types,
ImmutableList.of(BIGINT),
Optional.empty(),
LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -704,15 +704,13 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
@Test
public void writeUnblockWhenAllReadersFinish()
{
ImmutableList<Type> types = ImmutableList.of(BIGINT);

LocalExchange localExchange = new LocalExchange(
nodePartitioningManager,
SESSION,
2,
FIXED_BROADCAST_DISTRIBUTION,
ImmutableList.of(),
types,
ImmutableList.of(),
Optional.empty(),
LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
TYPE_OPERATOR_FACTORY,
Expand Down Expand Up @@ -760,7 +758,7 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed()
2,
FIXED_BROADCAST_DISTRIBUTION,
ImmutableList.of(),
TYPES,
ImmutableList.of(),
Optional.empty(),
DataSize.ofBytes(1),
TYPE_OPERATOR_FACTORY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,17 @@ public static BuildSideSetup setupBuildSide(

int partitionCount = parallelBuild ? PARTITION_COUNT : 1;
List<Integer> hashChannels = buildPages.getHashChannels().orElseThrow();
List<Type> types = buildPages.getTypes();
List<Type> hashChannelTypes = hashChannels.stream()
.map(types::get)
.collect(toImmutableList());
LocalExchange localExchange = new LocalExchange(
nodePartitioningManager,
taskContext.getSession(),
partitionCount,
FIXED_HASH_DISTRIBUTION,
hashChannels,
buildPages.getTypes(),
hashChannelTypes,
buildPages.getHashChannel(),
DataSize.of(32, DataSize.Unit.MEGABYTE),
TYPE_OPERATOR_FACTORY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.trino.operator.join.unspilled.HashBuilderOperator.HashBuilderOperatorFactory;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.JoinFilterFunctionCompiler;
import io.trino.sql.planner.NodePartitioningManager;
Expand Down Expand Up @@ -136,13 +137,17 @@ public static BuildSideSetup setupBuildSide(

int partitionCount = parallelBuild ? PARTITION_COUNT : 1;
List<Integer> hashChannels = buildPages.getHashChannels().orElseThrow();
List<Type> types = buildPages.getTypes();
List<Type> hashChannelTypes = hashChannels.stream()
.map(types::get)
.collect(toImmutableList());
LocalExchange localExchange = new LocalExchange(
nodePartitioningManager,
taskContext.getSession(),
partitionCount,
FIXED_HASH_DISTRIBUTION,
hashChannels,
buildPages.getTypes(),
hashChannelTypes,
buildPages.getHashChannel(),
DataSize.of(32, DataSize.Unit.MEGABYTE),
TYPE_OPERATOR_FACTORY,
Expand Down

0 comments on commit c7e0ecc

Please sign in to comment.