Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RLE Pages in JoinProbe #14493

Merged
merged 2 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public OperatorFactory join(
probeOutputChannelTypes,
lookupSourceFactory.getBuildOutputTypes(),
joinType,
new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, probeHashChannel),
new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, probeHashChannel, hasFilter),
blockTypeOperators,
probeJoinChannel,
probeHashChannel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,37 @@ public class JoinOperatorInfo
private final long[] logHistogramProbes;
private final long[] logHistogramOutput;
private final Optional<Long> lookupSourcePositions;
private final long rleProbes;
private final long totalProbes;

public static JoinOperatorInfo createJoinOperatorInfo(JoinType joinType, long[] logHistogramCounters, Optional<Long> lookupSourcePositions)
public static JoinOperatorInfo createJoinOperatorInfo(JoinType joinType, long[] logHistogramCounters, Optional<Long> lookupSourcePositions, long rleProbes, long totalProbes)
{
long[] logHistogramProbes = new long[HISTOGRAM_BUCKETS];
long[] logHistogramOutput = new long[HISTOGRAM_BUCKETS];
for (int i = 0; i < HISTOGRAM_BUCKETS; i++) {
logHistogramProbes[i] = logHistogramCounters[2 * i];
logHistogramOutput[i] = logHistogramCounters[2 * i + 1];
}
return new JoinOperatorInfo(joinType, logHistogramProbes, logHistogramOutput, lookupSourcePositions);
return new JoinOperatorInfo(joinType, logHistogramProbes, logHistogramOutput, lookupSourcePositions, rleProbes, totalProbes);
}

@JsonCreator
public JoinOperatorInfo(
@JsonProperty("joinType") JoinType joinType,
@JsonProperty("logHistogramProbes") long[] logHistogramProbes,
@JsonProperty("logHistogramOutput") long[] logHistogramOutput,
@JsonProperty("lookupSourcePositions") Optional<Long> lookupSourcePositions)
@JsonProperty("lookupSourcePositions") Optional<Long> lookupSourcePositions,
@JsonProperty("rleProbes") long rleProbes,
@JsonProperty("totalProbes") long totalProbes)
{
checkArgument(logHistogramProbes.length == HISTOGRAM_BUCKETS);
checkArgument(logHistogramOutput.length == HISTOGRAM_BUCKETS);
this.joinType = joinType;
this.logHistogramProbes = logHistogramProbes;
this.logHistogramOutput = logHistogramOutput;
this.lookupSourcePositions = lookupSourcePositions;
this.rleProbes = rleProbes;
this.totalProbes = totalProbes;
}

@JsonProperty
Expand Down Expand Up @@ -87,6 +93,18 @@ public Optional<Long> getLookupSourcePositions()
return lookupSourcePositions;
}

@JsonProperty
public long getRleProbes()
{
return rleProbes;
}

@JsonProperty
public long getTotalProbes()
{
return totalProbes;
}

@Override
public String toString()
{
Expand All @@ -95,6 +113,8 @@ public String toString()
.add("logHistogramProbes", logHistogramProbes)
.add("logHistogramOutput", logHistogramOutput)
.add("lookupSourcePositions", lookupSourcePositions)
.add("rleProbes", rleProbes)
.add("totalProbes", totalProbes)
.toString();
}

Expand All @@ -114,7 +134,7 @@ public JoinOperatorInfo mergeWith(JoinOperatorInfo other)
mergedSourcePositions = Optional.of(this.lookupSourcePositions.orElse(0L) + other.lookupSourcePositions.orElse(0L));
}

return new JoinOperatorInfo(this.joinType, logHistogramProbes, logHistogramOutput, mergedSourcePositions);
return new JoinOperatorInfo(this.joinType, logHistogramProbes, logHistogramOutput, mergedSourcePositions, this.rleProbes + other.rleProbes, this.totalProbes + other.totalProbes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class JoinStatisticsCounter
// [2*bucket + 1] total count of rows that were produces by probe rows in this bucket.
private final long[] logHistogramCounters = new long[HISTOGRAM_BUCKETS * 2];

private long rleProbes;
private long totalProbes;

/**
* Estimated number of positions in on the build side
*/
Expand Down Expand Up @@ -71,9 +74,19 @@ else if (numSourcePositions <= 100) {
logHistogramCounters[2 * bucket + 1] += numSourcePositions;
}

public void recordRleProbe()
{
rleProbes++;
}

public void recordCreateProbe()
{
totalProbes++;
}

@Override
public JoinOperatorInfo get()
{
return createJoinOperatorInfo(joinType, logHistogramCounters, lookupSourcePositions);
return createJoinOperatorInfo(joinType, logHistogramCounters, lookupSourcePositions, rleProbes, totalProbes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,61 @@
import io.trino.operator.join.LookupSource;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.RunLengthEncodedBlock;

import javax.annotation.Nullable;

import java.util.Arrays;
import java.util.List;
import java.util.OptionalInt;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.util.Objects.requireNonNull;

/**
* This class eagerly calculates all join positions and stores them in an array
* PageJoiner is responsible for ensuring that only the first position is processed for RLE with no or single build row match
*/
public class JoinProbe
raunaqmorarka marked this conversation as resolved.
Show resolved Hide resolved
{
public static class JoinProbeFactory
{
private final int[] probeOutputChannels;
private final int[] probeJoinChannels;
private final int probeHashChannel; // only valid when >= 0
private final boolean hasFilter;

public JoinProbeFactory(List<Integer> probeOutputChannels, List<Integer> probeJoinChannels, OptionalInt probeHashChannel)
public JoinProbeFactory(List<Integer> probeOutputChannels, List<Integer> probeJoinChannels, OptionalInt probeHashChannel, boolean hasFilter)
{
this.probeOutputChannels = Ints.toArray(requireNonNull(probeOutputChannels, "probeOutputChannels is null"));
this.probeJoinChannels = Ints.toArray(requireNonNull(probeJoinChannels, "probeJoinChannels is null"));
this.probeHashChannel = requireNonNull(probeHashChannel, "probeHashChannel is null").orElse(-1);
this.hasFilter = hasFilter;
}

public JoinProbe createJoinProbe(Page page, LookupSource lookupSource)
{
Page probePage = page.getLoadedPage(probeJoinChannels);
return new JoinProbe(probeOutputChannels, page, probePage, lookupSource, probeHashChannel >= 0 ? page.getBlock(probeHashChannel).getLoadedBlock() : null);
return new JoinProbe(probeOutputChannels, page, probePage, lookupSource, probeHashChannel >= 0 ? page.getBlock(probeHashChannel).getLoadedBlock() : null, hasFilter);
}
}

private final int[] probeOutputChannels;
private final Page page;
private final long[] joinPositionCache;
private final boolean isRle;
skrzypo987 marked this conversation as resolved.
Show resolved Hide resolved
private int position = -1;

private JoinProbe(int[] probeOutputChannels, Page page, Page probePage, LookupSource lookupSource, @Nullable Block probeHashBlock)
private JoinProbe(int[] probeOutputChannels, Page page, Page probePage, LookupSource lookupSource, @Nullable Block probeHashBlock, boolean hasFilter)
{
this.probeOutputChannels = requireNonNull(probeOutputChannels, "probeOutputChannels is null");
this.page = requireNonNull(page, "page is null");

joinPositionCache = fillCache(lookupSource, page, probeHashBlock, probePage);
// if filter channels are not RLE encoded, then every probe
// row might be unique and must be matched independently
this.isRle = !hasFilter && hasOnlyRleBlocks(probePage);
joinPositionCache = fillCache(lookupSource, page, probeHashBlock, probePage, isRle);
}

public int[] getOutputChannels()
Expand All @@ -76,6 +85,11 @@ public boolean advanceNextPosition()
return !isFinished();
}

public void finish()
{
position = page.getPositionCount();
}

public boolean isFinished()
{
return position == page.getPositionCount();
Expand All @@ -91,6 +105,11 @@ public int getPosition()
return position;
}

public boolean areProbeJoinChannelsRunLengthEncoded()
{
return isRle;
}

public Page getPage()
{
return page;
Expand All @@ -100,19 +119,49 @@ private static long[] fillCache(
LookupSource lookupSource,
Page page,
Block probeHashBlock,
Page probePage)
Page probePage,
boolean isRle)
{
int positionCount = page.getPositionCount();
List<Block> nullableBlocks = IntStream.range(0, probePage.getChannelCount())
.mapToObj(i -> probePage.getBlock(i))
.filter(Block::mayHaveNull)
.collect(toImmutableList());

Block[] nullableBlocks = new Block[probePage.getChannelCount()];
int nullableBlocksCount = 0;
for (int channel = 0; channel < probePage.getChannelCount(); channel++) {
Block probeBlock = probePage.getBlock(channel);
if (probeBlock.mayHaveNull()) {
nullableBlocks[nullableBlocksCount++] = probeBlock;
}
}

if (isRle) {
raunaqmorarka marked this conversation as resolved.
Show resolved Hide resolved
long[] joinPositionCache;
// Null values cannot be joined, so if any column contains null, there is no match
boolean anyAllNullsBlock = false;
for (int i = 0; i < nullableBlocksCount; i++) {
Block nullableBlock = nullableBlocks[i];
if (nullableBlock.isNull(0)) {
anyAllNullsBlock = true;
break;
}
}
if (anyAllNullsBlock) {
joinPositionCache = new long[1];
joinPositionCache[0] = -1;
}
else {
joinPositionCache = new long[positionCount];
// We can fall back to processing all positions in case there are multiple build rows matched for the first probe position
Arrays.fill(joinPositionCache, lookupSource.getJoinPosition(0, probePage, page));
}

return joinPositionCache;
}

long[] joinPositionCache = new long[positionCount];
if (!nullableBlocks.isEmpty()) {
if (nullableBlocksCount > 0) {
Arrays.fill(joinPositionCache, -1);
boolean[] isNull = new boolean[positionCount];
int nonNullCount = getIsNull(nullableBlocks, positionCount, isNull);
int nonNullCount = getIsNull(nullableBlocks, nullableBlocksCount, positionCount, isNull);
if (nonNullCount < positionCount) {
// We only store positions that are not null
int[] positions = new int[nonNullCount];
Expand Down Expand Up @@ -155,22 +204,36 @@ private static long[] fillCache(
return joinPositionCache;
}

private static int getIsNull(List<Block> nullableBlocks, int positionCount, boolean[] isNull)
private static int getIsNull(Block[] nullableBlocks, int nullableBlocksCount, int positionCount, boolean[] isNull)
{
for (int i = 0; i < nullableBlocks.size() - 1; i++) {
Block block = nullableBlocks.get(i);
for (int i = 0; i < nullableBlocksCount - 1; i++) {
Block block = nullableBlocks[i];
for (int position = 0; position < positionCount; position++) {
isNull[position] |= block.isNull(position);
}
}
// Last block will also calculate `nonNullCount`
int nonNullCount = 0;
Block lastBlock = nullableBlocks.get(nullableBlocks.size() - 1);
Block lastBlock = nullableBlocks[nullableBlocksCount - 1];
for (int position = 0; position < positionCount; position++) {
isNull[position] |= lastBlock.isNull(position);
nonNullCount += isNull[position] ? 0 : 1;
}

return nonNullCount;
}

private static boolean hasOnlyRleBlocks(Page probePage)
{
if (probePage.getChannelCount() == 0) {
return false;
}

for (int i = 0; i < probePage.getChannelCount(); i++) {
if (!(probePage.getBlock(i) instanceof RunLengthEncodedBlock)) {
return false;
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import it.unimi.dsi.fastutil.ints.IntArrayList;

Expand All @@ -43,6 +44,7 @@ public class LookupJoinPageBuilder
private int estimatedProbeRowSize = -1;
private int previousPosition = -1;
private boolean isSequentialProbeIndices = true;
private boolean repeatBuildRow;

public LookupJoinPageBuilder(List<Type> buildTypes)
{
Expand All @@ -62,6 +64,13 @@ public boolean isEmpty()
return probeIndexBuilder.isEmpty() && buildPageBuilder.isEmpty();
}

public int getPositionCount()
{
// when build rows are repeated then position count is equal to probe position count
verify(!repeatBuildRow);
return probeIndexBuilder.size();
}

public void reset()
{
// be aware that probeIndexBuilder will not clear its capacity
Expand All @@ -71,6 +80,7 @@ public void reset()
estimatedProbeRowSize = -1;
previousPosition = -1;
isSequentialProbeIndices = true;
repeatBuildRow = false;
}

/**
Expand Down Expand Up @@ -101,8 +111,17 @@ public void appendNullForBuild(JoinProbe probe)
}
}

public void repeatBuildRow()
{
repeatBuildRow = true;
}

public Page build(JoinProbe probe)
{
if (repeatBuildRow) {
return buildRepeatedPage(probe);
}

int outputPositions = probeIndexBuilder.size();
verify(buildPageBuilder.getPositionCount() == outputPositions);

Expand Down Expand Up @@ -140,6 +159,32 @@ public Page build(JoinProbe probe)
return new Page(outputPositions, blocks);
}

private Page buildRepeatedPage(JoinProbe probe)
{
// Build match can be repeated only if there is a single build row match
// and probe join channels are run length encoded.
verify(probe.areProbeJoinChannelsRunLengthEncoded());
verify(buildPageBuilder.getPositionCount() == 1);
verify(probeIndexBuilder.size() == 1);
verify(probeIndexBuilder.getInt(0) == 0);

int positionCount = probe.getPage().getPositionCount();
int[] probeOutputChannels = probe.getOutputChannels();
Block[] blocks = new Block[probeOutputChannels.length + buildOutputChannelCount];

for (int i = 0; i < probeOutputChannels.length; i++) {
blocks[i] = probe.getPage().getBlock(probeOutputChannels[i]);
}

int offset = probeOutputChannels.length;
for (int i = 0; i < buildOutputChannelCount; i++) {
Block buildBlock = buildPageBuilder.getBlockBuilder(i).build();
blocks[offset + i] = RunLengthEncodedBlock.create(buildBlock, positionCount);
}

return new Page(positionCount, blocks);
}

@Override
public String toString()
{
Expand Down
Loading