Skip to content

Commit

Permalink
Allow local exchanges on non system partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
arhimondr committed Sep 5, 2019
1 parent 85efb93 commit 59a230a
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@
*/
package com.facebook.presto.operator.exchange;

import com.facebook.presto.Session;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.operator.BucketPartitionFunction;
import com.facebook.presto.operator.HashGenerator;
import com.facebook.presto.operator.InterpretedHashGenerator;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.operator.PipelineExecutionStrategy;
import com.facebook.presto.operator.PrecomputedHashGenerator;
import com.facebook.presto.spi.BucketFunction;
import com.facebook.presto.spi.connector.ConnectorBucketNodeMap;
import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;

Expand All @@ -34,6 +45,7 @@
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION;
Expand Down Expand Up @@ -76,11 +88,13 @@ public class LocalExchange
private int nextSourceIndex;

public LocalExchange(
PartitioningProviderManager partitioningProviderManager,
Session session,
int sinkFactoryCount,
int bufferCount,
PartitioningHandle partitioning,
List<? extends Type> types,
List<Integer> partitionChannels,
List<Type> partitioningChannelTypes,
Optional<Integer> partitionHashChannel,
DataSize maxBufferedBytes)
{
Expand Down Expand Up @@ -110,21 +124,76 @@ else if (partitioning.equals(FIXED_BROADCAST_DISTRIBUTION)) {
else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
exchangerSupplier = () -> new RandomExchanger(buffers, memoryManager);
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION)) {
exchangerSupplier = () -> new PartitioningExchanger(buffers, memoryManager, types, partitionChannels, partitionHashChannel);
}
else if (partitioning.equals(FIXED_PASSTHROUGH_DISTRIBUTION)) {
Iterator<LocalExchangeSource> sourceIterator = this.sources.iterator();
exchangerSupplier = () -> {
checkState(sourceIterator.hasNext(), "no more sources");
return new PassthroughExchanger(sourceIterator.next(), maxBufferedBytes.toBytes() / bufferCount, memoryManager::updateMemoryUsage);
};
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) {
// partitioned exchange
exchangerSupplier = () -> new PartitioningExchanger(
buffers,
memoryManager,
createPartitionFunction(
partitioningProviderManager,
session,
partitioning,
bufferCount,
partitioningChannelTypes,
partitionHashChannel.isPresent()),
partitionChannels,
partitionHashChannel);
}
else {
throw new IllegalArgumentException("Unsupported local exchange partitioning " + partitioning);
}
}

/**
 * Builds the partition function for a local partitioning exchange.
 * <p>
 * System partitioning hashes the partitioning channels (or reuses a precomputed
 * hash channel); connector partitioning delegates to the connector's bucket
 * function and folds buckets onto the available local partitions round-robin.
 */
private static PartitionFunction createPartitionFunction(
        PartitioningProviderManager partitioningProviderManager,
        Session session,
        PartitioningHandle partitioning,
        int partitionCount,
        List<Type> partitioningChannelTypes,
        boolean isHashPrecomputed)
{
    if (partitioning.getConnectorHandle() instanceof SystemPartitioningHandle) {
        // The hash is either already materialized in channel 0 of the extracted
        // page, or computed on the fly over every partitioning channel.
        HashGenerator generator = isHashPrecomputed
                ? new PrecomputedHashGenerator(0)
                : new InterpretedHashGenerator(partitioningChannelTypes, IntStream.range(0, partitioningChannelTypes.size()).toArray());
        return new LocalPartitionGenerator(generator, partitionCount);
    }

    // Connector-provided partitioning: resolve the connector's bucketing scheme.
    // NOTE(review): getConnectorId().get() assumes callers only reach here when a
    // connector id is present — the constructor's dispatch appears to guarantee this.
    ConnectorNodePartitioningProvider provider = partitioningProviderManager.getPartitioningProvider(partitioning.getConnectorId().get());
    ConnectorBucketNodeMap bucketNodeMap = provider.getBucketNodeMap(
            partitioning.getTransactionHandle().orElse(null),
            session.toConnectorSession(),
            partitioning.getConnectorHandle());
    checkArgument(bucketNodeMap != null, "No partition map %s", partitioning);

    // Assign each connector bucket to a local partition in round-robin order.
    int bucketCount = bucketNodeMap.getBucketCount();
    int[] bucketToPartition = IntStream.range(0, bucketCount)
            .map(bucket -> bucket % partitionCount)
            .toArray();

    BucketFunction bucketFunction = provider.getBucketFunction(
            partitioning.getTransactionHandle().orElse(null),
            session.toConnectorSession(),
            partitioning.getConnectorHandle(),
            partitioningChannelTypes,
            bucketCount);
    checkArgument(bucketFunction != null, "No bucket function for partitioning: %s", partitioning);

    return new BucketPartitionFunction(bucketFunction, bucketToPartition);
}

public int getBufferCount()
{
return sources.size();
Expand Down Expand Up @@ -255,9 +324,11 @@ private static void checkNotHoldsLock(Object lock)
@ThreadSafe
public static class LocalExchangeFactory
{
private final PartitioningProviderManager partitioningProviderManager;
private final Session session;
private final PartitioningHandle partitioning;
private final List<Type> types;
private final List<Integer> partitionChannels;
private final List<Type> partitioningChannelTypes;
private final Optional<Integer> partitionHashChannel;
private final PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy;
private final DataSize maxBufferedBytes;
Expand All @@ -276,6 +347,8 @@ public static class LocalExchangeFactory
private final List<LocalExchangeSinkFactoryId> closedSinkFactories = new ArrayList<>();

public LocalExchangeFactory(
PartitioningProviderManager partitioningProviderManager,
Session session,
PartitioningHandle partitioning,
int defaultConcurrency,
List<Type> types,
Expand All @@ -284,14 +357,18 @@ public LocalExchangeFactory(
PipelineExecutionStrategy exchangeSourcePipelineExecutionStrategy,
DataSize maxBufferedBytes)
{
this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
this.session = requireNonNull(session, "session is null");
this.partitioning = requireNonNull(partitioning, "partitioning is null");
this.types = requireNonNull(types, "types is null");
this.partitionChannels = requireNonNull(partitionChannels, "partitioningChannels is null");
this.bufferCount = computeBufferCount(partitioning, defaultConcurrency, partitionChannels);
this.partitionChannels = ImmutableList.copyOf(requireNonNull(partitionChannels, "partitionChannels is null"));
requireNonNull(types, "types is null");
this.partitioningChannelTypes = partitionChannels.stream()
.map(types::get)
.collect(toImmutableList());
this.partitionHashChannel = requireNonNull(partitionHashChannel, "partitionHashChannel is null");
this.exchangeSourcePipelineExecutionStrategy = requireNonNull(exchangeSourcePipelineExecutionStrategy, "exchangeSourcePipelineExecutionStrategy is null");
this.maxBufferedBytes = requireNonNull(maxBufferedBytes, "maxBufferedBytes is null");

this.bufferCount = computeBufferCount(partitioning, defaultConcurrency, partitionChannels);
}

public synchronized LocalExchangeSinkFactoryId newSinkFactoryId()
Expand Down Expand Up @@ -322,8 +399,16 @@ public synchronized LocalExchange getLocalExchange(Lifespan lifespan)
}
return localExchangeMap.computeIfAbsent(lifespan, ignored -> {
checkState(noMoreSinkFactories);
LocalExchange localExchange =
new LocalExchange(numSinkFactories, bufferCount, partitioning, types, partitionChannels, partitionHashChannel, maxBufferedBytes);
LocalExchange localExchange = new LocalExchange(
partitioningProviderManager,
session,
numSinkFactories,
bufferCount,
partitioning,
partitionChannels,
partitioningChannelTypes,
partitionHashChannel,
maxBufferedBytes);
for (LocalExchangeSinkFactoryId closedSinkFactoryId : closedSinkFactories) {
localExchange.getSinkFactory(closedSinkFactoryId).close();
}
Expand Down Expand Up @@ -355,14 +440,14 @@ else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
bufferCount = defaultConcurrency;
checkArgument(partitionChannels.isEmpty(), "Arbitrary exchange must not have partition channels");
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION)) {
bufferCount = defaultConcurrency;
checkArgument(!partitionChannels.isEmpty(), "Partitioned exchange must have partition channels");
}
else if (partitioning.equals(FIXED_PASSTHROUGH_DISTRIBUTION)) {
bufferCount = defaultConcurrency;
checkArgument(partitionChannels.isEmpty(), "Passthrough exchange must not have partition channels");
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) {
// partitioned exchange
bufferCount = defaultConcurrency;
}
else {
throw new IllegalArgumentException("Unsupported local exchange partitioning " + partitioning);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@
*/
package com.facebook.presto.operator.exchange;

import com.facebook.presto.operator.HashGenerator;
import com.facebook.presto.operator.InterpretedHashGenerator;
import com.facebook.presto.operator.PrecomputedHashGenerator;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
Expand All @@ -29,38 +25,30 @@
import java.util.Optional;
import java.util.function.Consumer;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

class PartitioningExchanger
implements LocalExchanger
{
private final List<Consumer<PageReference>> buffers;
private final LocalExchangeMemoryManager memoryManager;
private final LocalPartitionGenerator partitionGenerator;
private final PartitionFunction partitionFunction;
private final List<Integer> partitioningChannels;
private final Optional<Integer> hashChannel;
private final IntArrayList[] partitionAssignments;

public PartitioningExchanger(
List<Consumer<PageReference>> partitions,
LocalExchangeMemoryManager memoryManager,
List<? extends Type> types,
List<Integer> partitionChannels,
PartitionFunction partitionFunction,
List<Integer> partitioningChannels,
Optional<Integer> hashChannel)
{
this.buffers = ImmutableList.copyOf(requireNonNull(partitions, "partitions is null"));
this.memoryManager = requireNonNull(memoryManager, "memoryManager is null");

HashGenerator hashGenerator;
if (hashChannel.isPresent()) {
hashGenerator = new PrecomputedHashGenerator(hashChannel.get());
}
else {
List<Type> partitionChannelTypes = partitionChannels.stream()
.map(types::get)
.collect(toImmutableList());
hashGenerator = new InterpretedHashGenerator(partitionChannelTypes, Ints.toArray(partitionChannels));
}
partitionGenerator = new LocalPartitionGenerator(hashGenerator, buffers.size());
this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null");
this.partitioningChannels = ImmutableList.copyOf(requireNonNull(partitioningChannels, "partitioningChannels is null"));
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");

partitionAssignments = new IntArrayList[partitions.size()];
for (int i = 0; i < partitionAssignments.length; i++) {
Expand All @@ -77,8 +65,9 @@ public synchronized void accept(Page page)
}

// assign each row to a partition
for (int position = 0; position < page.getPositionCount(); position++) {
int partition = partitionGenerator.getPartition(page, position);
Page partitioningChannelsPage = extractPartitioningChannels(page);
for (int position = 0; position < partitioningChannelsPage.getPositionCount(); position++) {
int partition = partitionFunction.getPartition(partitioningChannelsPage, position);
partitionAssignments[partition].add(position);
}

Expand All @@ -98,6 +87,21 @@ public synchronized void accept(Page page)
}
}

/**
 * Projects the input page down to only the channels relevant for partitioning:
 * the single precomputed hash channel when present, otherwise the raw
 * partitioning channels.
 */
private Page extractPartitioningChannels(Page inputPage)
{
    // With a precomputed hash, that one channel alone drives partitioning.
    if (hashChannel.isPresent()) {
        return new Page(inputPage.getBlock(hashChannel.get()));
    }

    // Otherwise gather the partitioning channels into a narrow page.
    Block[] partitioningBlocks = partitioningChannels.stream()
            .map(inputPage::getBlock)
            .toArray(Block[]::new);
    return new Page(inputPage.getPositionCount(), partitioningBlocks);
}

@Override
public ListenableFuture<?> waitForWriting()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ public class LocalExecutionPlanner
private final Optional<ExplainAnalyzeContext> explainAnalyzeContext;
private final PageSourceProvider pageSourceProvider;
private final IndexManager indexManager;
private final PartitioningProviderManager partitioningProviderManager;
private final NodePartitioningManager nodePartitioningManager;
private final PageSinkManager pageSinkManager;
private final ExpressionCompiler expressionCompiler;
Expand Down Expand Up @@ -312,6 +313,7 @@ public LocalExecutionPlanner(
Optional<ExplainAnalyzeContext> explainAnalyzeContext,
PageSourceProvider pageSourceProvider,
IndexManager indexManager,
PartitioningProviderManager partitioningProviderManager,
NodePartitioningManager nodePartitioningManager,
PageSinkManager pageSinkManager,
ExpressionCompiler expressionCompiler,
Expand All @@ -332,6 +334,7 @@ public LocalExecutionPlanner(
this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null");
this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
this.indexManager = requireNonNull(indexManager, "indexManager is null");
this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
Expand Down Expand Up @@ -2422,6 +2425,8 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan
int operatorsCount = subContext.getDriverInstanceCount().orElse(1);
List<Type> types = getSourceOperatorTypes(node, context.getTypes());
LocalExchangeFactory exchangeFactory = new LocalExchangeFactory(
partitioningProviderManager,
session,
node.getPartitioningScheme().getPartitioning().getHandle(),
operatorsCount,
types,
Expand Down Expand Up @@ -2495,6 +2500,8 @@ else if (context.getDriverInstanceCount().isPresent()) {
}

LocalExchangeFactory localExchangeFactory = new LocalExchangeFactory(
partitioningProviderManager,
session,
node.getPartitioningScheme().getPartitioning().getHandle(),
driverInstanceCount,
types,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ public PlanNodeId getNextId()
Optional.empty(),
pageSourceManager,
indexManager,
partitioningProviderManager,
nodePartitioningManager,
pageSinkManager,
expressionCompiler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public static LocalExecutionPlanner createTestingPlanner()
Optional.empty(),
pageSourceManager,
new IndexManager(),
partitioningProviderManager,
nodePartitioningManager,
new PageSinkManager(),
new ExpressionCompiler(metadata, pageFunctionCompiler),
Expand Down
Loading

0 comments on commit 59a230a

Please sign in to comment.