Skip to content

Commit

Permalink
Fix full outer join result mismatch issue
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Sep 23, 2024
1 parent 91635cd commit 28c8a14
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 71 deletions.
135 changes: 64 additions & 71 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "velox/exec/MergeJoin.h"
#include <iostream>
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/FieldReference.h"
Expand Down Expand Up @@ -353,6 +354,9 @@ void MergeJoin::addOutputRow(
copyRow(right, rightIndex, output_, outputSize_, rightProjections_);
}

// std::cout << "the output in addOutputRow is "
// << output_->toString(0, outputSize_ + 1) << "\n";

if (filter_) {
copyRow(left, leftIndex, filterInput_, outputSize_, filterLeftInputs_);
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);
Expand Down Expand Up @@ -531,23 +535,28 @@ bool MergeJoin::addToOutputForLeftJoin() {
auto leftStart = l == firstLeftBatch ? leftStartIndex : 0;
auto leftEnd = l == numLefts - 1 ? leftMatch_->endIndex : left->size();

auto rightEnd = 0;
auto rightStart = 0;
auto firstRightBatch = 0;
auto rightStartIndex = 0;
auto numRights = 0;
for (auto i = leftStart; i < leftEnd; ++i) {
auto firstRightBatch =
firstRightBatch =
(l == firstLeftBatch && i == leftStart && rightMatch_->cursor)
? rightMatch_->cursor->batchIndex
: 0;

auto rightStartIndex =
rightStartIndex =
(l == firstLeftBatch && i == leftStart && rightMatch_->cursor)
? rightMatch_->cursor->index
: rightMatch_->startIndex;

auto numRights = rightMatch_->inputs.size();
numRights = rightMatch_->inputs.size();

for (size_t r = firstRightBatch; r < numRights; ++r) {
auto right = rightMatch_->inputs[r];
auto rightStart = r == firstRightBatch ? rightStartIndex : 0;
auto rightEnd =
r == numRights - 1 ? rightMatch_->endIndex : right->size();
rightStart = r == firstRightBatch ? rightStartIndex : 0;
rightEnd = r == numRights - 1 ? rightMatch_->endIndex : right->size();

if (prepareOutput(left, right)) {
output_->resize(outputSize_);
Expand Down Expand Up @@ -576,10 +585,58 @@ bool MergeJoin::addToOutputForLeftJoin() {
rightMatch_->setCursor(r, j);
return true;
}

addOutputRow(left, i, right, j);
}
}
}

if (isFullJoin(joinType_)) {
// Apply the filter so that a right-side row with no matching left-side row
// is emitted with nulls on the left side.
auto numRows = (leftEnd - leftStart) * (rightEnd - rightStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange((outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);
for (size_t r = firstRightBatch; r < numRights; ++r) {
auto right = rightMatch_->inputs[r];
for (auto i = rightStart; i < rightEnd; ++i) {
bool rightMatched = false;
for (auto j = leftStart; j < leftEnd; ++j) {
auto rowIndex = processedRowNums +
(j - leftStart) * (rightEnd - rightStart) + i - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
rightMatched = passed;
}
}

if (!rightMatched) {
// Add a new row for the unmatched right-side row and set the left
// side to null.
if (!isRightFlattened_) {
rawRightIndices_[outputSize_] = i;
} else {
copyRow(right, i, output_, outputSize_, rightProjections_);
}

for (const auto& projection : leftProjections_) {
const auto& target = output_->childAt(projection.outputChannel);
target->setNull(outputSize_, true);
}

joinTracker_->addMiss(outputSize_);

++outputSize_;
}
}
}
}
}

leftMatch_.reset();
Expand Down Expand Up @@ -1128,8 +1185,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const auto numRows = output->size();

RowVectorPtr fullOuterOutput = nullptr;

BufferPtr indices = allocateIndices(numRows, pool());
auto rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;
Expand All @@ -1150,61 +1205,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
// For filtered rows, it is necessary to insert additional data
// to ensure the result set is complete. Specifically, we
// need to generate two records: one record containing the
// columns from the left table along with nulls for the
// right table, and another record containing the columns
// from the right table along with nulls for the left table.
// For instance, the current output is filtered based on the condition
// t > 1.

// 1, 1
// 2, 2
// 3, 3

// In this scenario, we need to additionally insert a record 1, 1.
// Subsequently, we will set the values of the columns on the left to
// null and the values of the columns on the right to null as well. By
// doing so, we will obtain the final result set.

// 1, null
// null, 1
// 2, 2
// 3, 3
fullOuterOutput = BaseVector::create<RowVector>(
output->type(), output->size() + 1, pool());

for (auto i = 0; i < row + 1; i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i, i, 1);
}
}

for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), row + 1, row, 1);
}

for (auto i = row + 1; i < output->size(); i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i + 1, i, 1);
}
}

for (auto& projection : leftProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row, true);
}

for (auto& projection : rightProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
Expand Down Expand Up @@ -1284,17 +1285,9 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (numPassed == numRows) {
// All rows passed.
if (fullOuterOutput) {
return fullOuterOutput;
}
return output;
}

// Some, but not all rows passed.
if (fullOuterOutput) {
return wrap(numPassed, indices, fullOuterOutput);
}

return wrap(numPassed, indices, output);
}

Expand Down
39 changes: 39 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,45 @@ TEST_F(MergeJoinTest, leftAndRightJoinFilter) {
}
}

TEST_F(MergeJoinTest, fullOuterJoinWithDuplicateMatch) {
  // Full outer merge join on a = c with the extra filter b < d.
  //
  // Join key 2 occurs three times in 'a' and four times in 'c', so both sides
  // carry duplicate matches for the same key. Keys 1 and 6 exist only on the
  // left, keys 0, 4 and 7 exist only on the right, and each input ends with a
  // null key.
  auto leftInput = makeRowVector(
      {"a", "b"},
      {
          makeNullableFlatVector<int32_t>({1, 2, 2, 2, 3, 5, 6, std::nullopt}),
          makeNullableFlatVector<double>(
              {2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}),
      });

  auto rightInput = makeRowVector(
      {"c", "d"},
      {
          makeNullableFlatVector<int32_t>(
              {0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}),
          makeNullableFlatVector<double>(
              {0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}),
      });

  // Register both inputs as the DuckDB tables referenced by the SQL below.
  createDuckDbTable("t", {leftInput});
  createDuckDbTable("u", {rightInput});

  auto idGenerator = std::make_shared<core::PlanNodeIdGenerator>();

  // The right-side source is built inline as the merge join's build argument,
  // sharing the same plan node id generator as the left-side source.
  auto plan =
      PlanBuilder(idGenerator)
          .values({leftInput})
          .mergeJoin(
              {"a"},
              {"c"},
              PlanBuilder(idGenerator).values({rightInput}).planNode(),
              "b < d",
              {"a", "b", "c", "d"},
              core::JoinType::kFull)
          .planNode();

  // The expected rows are specified by the equivalent SQL query.
  AssertQueryBuilder(plan, duckDbQueryRunner_)
      .assertResults("SELECT * from t FULL OUTER JOIN u ON a = c AND b < d");
}

TEST_F(MergeJoinTest, rightJoinWithDuplicateMatch) {
// Each row on the left side has at most one match on the right side.
auto left = makeRowVector(
Expand Down

0 comments on commit 28c8a14

Please sign in to comment.