Skip to content

Commit

Permalink
Cleanup, refactoring, and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kryonix committed Aug 13, 2024
1 parent 4e158f6 commit 7b20670
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ struct FlattenDependentJoins {
column_binding_map_t<idx_t> replacement_map;
const vector<CorrelatedColumnInfo> &correlated_columns;
vector<LogicalType> delim_types;

// Collection of recursive and materialized CTE ids.
// This is used to correctly decorrelate nested EXISTS subqueries.
vector<idx_t> cte_idx;

bool perform_delim;
Expand Down
9 changes: 2 additions & 7 deletions src/include/duckdb/planner/subquery/rewrite_subquery.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ namespace duckdb {
class RewriteSubquery : public LogicalOperatorVisitor {
public:
RewriteSubquery(idx_t table_index, idx_t lateral_depth, ColumnBinding base_binding,
const vector<CorrelatedColumnInfo> &correlated_columns,
column_binding_map_t<idx_t> &correlated_map);
const vector<CorrelatedColumnInfo> &correlated_columns);

void VisitOperator(LogicalOperator &op) override;
unique_ptr<Expression> VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) override;
Expand All @@ -29,14 +28,12 @@ class RewriteSubquery : public LogicalOperatorVisitor {
idx_t lateral_depth;
ColumnBinding base_binding;
const vector<CorrelatedColumnInfo> &correlated_columns;
column_binding_map_t<idx_t> &correlated_map;
};

class RewriteCorrelatedSubqueriesRecursive : public BoundNodeVisitor {
public:
RewriteCorrelatedSubqueriesRecursive(idx_t table_index, idx_t lateral_depth, ColumnBinding base_binding,
const vector<CorrelatedColumnInfo> &correlated_columns,
column_binding_map_t<idx_t> &correlated_map);
const vector<CorrelatedColumnInfo> &correlated_columns);

void VisitBoundTableRef(BoundTableRef &ref) override;
void VisitExpression(unique_ptr<Expression> &expression) override;
Expand All @@ -47,10 +44,8 @@ class RewriteCorrelatedSubqueriesRecursive : public BoundNodeVisitor {
idx_t lateral_depth;
ColumnBinding base_binding;
const vector<CorrelatedColumnInfo> &correlated_columns;
column_binding_map_t<idx_t> &correlated_map;
bool add_filter = false;
unique_ptr<Expression> condition;
vector<CorrelatedColumnInfo> add_correlation;
};

} // namespace duckdb
2 changes: 2 additions & 0 deletions src/planner/binder/tableref/bind_basetableref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ unique_ptr<BoundTableRef> Binder::Bind(BaseTableRef &ref) {
result->types = ctebinding->types;
result->bound_columns = std::move(names);

// Traverse the parent binders in order to find the recursive CTE node
// this CTE reads from. This is necessary to correctly bind correlated columns.
Binder *current = this;
while (current) {
auto rec_cte = current->bound_cte_nodes.find(ctebinding->index);
Expand Down
10 changes: 5 additions & 5 deletions src/planner/subquery/flatten_dependent_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal
PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth);

for (auto idx : cte_idx) {
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns, correlated_map);
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns);
subquery_rewriter.VisitOperator(*plan);
}

Expand Down Expand Up @@ -208,7 +208,7 @@ unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal
}

for (auto idx : cte_idx) {
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns, correlated_map);
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns);
subquery_rewriter.VisitOperator(*plan);
}

Expand Down Expand Up @@ -392,7 +392,7 @@ unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal
PushDownDependentJoinInternal(std::move(plan->children[0]), parent_propagate_null_values, lateral_depth);

for (auto idx : cte_idx) {
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns, correlated_map);
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns);
subquery_rewriter.VisitOperator(*plan);
}

Expand Down Expand Up @@ -457,7 +457,7 @@ unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal
parent_propagate_null_values, lateral_depth);

for (auto idx : cte_idx) {
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns, correlated_map);
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns);
subquery_rewriter.VisitOperator(*plan);
}

Expand Down Expand Up @@ -511,7 +511,7 @@ unique_ptr<LogicalOperator> FlattenDependentJoins::PushDownDependentJoinInternal
}

for (auto idx : cte_idx) {
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns, correlated_map);
RewriteSubquery subquery_rewriter(idx, lateral_depth, base_binding, correlated_columns);
subquery_rewriter.VisitOperator(*plan);
}

Expand Down
35 changes: 13 additions & 22 deletions src/planner/subquery/rewrite_subquery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,29 @@
namespace duckdb {

RewriteSubquery::RewriteSubquery(idx_t table_index, idx_t lateral_depth, ColumnBinding base_binding,
const vector<CorrelatedColumnInfo> &correlated_columns,
column_binding_map_t<idx_t> &correlated_map)
const vector<CorrelatedColumnInfo> &correlated_columns)
: table_index(table_index), lateral_depth(lateral_depth), base_binding(base_binding),
correlated_columns(correlated_columns), correlated_map(correlated_map) {
correlated_columns(correlated_columns) {
}

void RewriteSubquery::VisitOperator(duckdb::LogicalOperator &op) {
VisitOperatorExpressions(op);
}

unique_ptr<Expression> RewriteSubquery::VisitReplace(BoundSubqueryExpression &expr, unique_ptr<Expression> *expr_ptr) {
// if (!expr.IsCorrelated()) {
// return nullptr;
// }
// subquery detected within this subquery
// recursively rewrite it using the RewriteCorrelatedRecursive class
RewriteCorrelatedSubqueriesRecursive rewrite(table_index, lateral_depth, this->base_binding, correlated_columns,
correlated_map);
RewriteCorrelatedSubqueriesRecursive rewrite(table_index, lateral_depth, this->base_binding, correlated_columns);
bool rewrite_cte = expr.subquery_type == SubqueryType::EXISTS || expr.subquery_type == SubqueryType::NOT_EXISTS;
rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery, rewrite_cte);
return nullptr;
}

RewriteCorrelatedSubqueriesRecursive::RewriteCorrelatedSubqueriesRecursive(
idx_t table_index, idx_t lateral_depth, ColumnBinding base_binding,
const vector<CorrelatedColumnInfo> &correlated_columns, column_binding_map_t<idx_t> &correlated_map)
const vector<CorrelatedColumnInfo> &correlated_columns)
: table_index(table_index), lateral_depth(lateral_depth), base_binding(base_binding),
correlated_columns(correlated_columns), correlated_map(correlated_map) {
correlated_columns(correlated_columns) {
}

void RewriteCorrelatedSubqueriesRecursive::VisitBoundTableRef(BoundTableRef &ref) {
Expand All @@ -54,24 +49,25 @@ void RewriteCorrelatedSubqueriesRecursive::VisitBoundTableRef(BoundTableRef &ref
} else if (ref.type == TableReferenceType::CTE && add_filter) {
auto &cteref = ref.Cast<BoundCTERef>();

// check if this is the CTE we are looking for
if (cteref.cte_index == table_index) {
// this is the CTE we are looking for: add a filter to the CTE
// we add a filter to the CTE that compares the correlated columns of the CTE to the correlated columns of
// the outer query. This filter is added to the WHERE clause of the subquery.
for (idx_t i = 0; i < correlated_columns.size(); i++) {
auto &col = correlated_columns[i];
// add_correlation.emplace_back(CorrelatedColumnInfo(col.binding, col.type, col.name, 1));
add_correlation.emplace_back(
CorrelatedColumnInfo(ColumnBinding(base_binding.table_index, base_binding.column_index + i),
col.type, col.name, lateral_depth + 1));
auto row_num_ref = make_uniq<BoundColumnRefExpression>(
auto outer = make_uniq<BoundColumnRefExpression>(
col.name, col.type, ColumnBinding(base_binding.table_index, base_binding.column_index + i),
lateral_depth + 1);

auto row_num_ref1 = make_uniq<BoundColumnRefExpression>(
auto inner = make_uniq<BoundColumnRefExpression>(
col.name, col.type, ColumnBinding(cteref.bind_index, cteref.bound_columns.size() + i));
auto comp = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_NOT_DISTINCT_FROM,
row_num_ref->Copy(), row_num_ref1->Copy());
std::move(outer), std::move(inner));
if (condition) {
auto conj = make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND, std::move(comp),
std::move(condition));
condition = std::move(conj);
} else {
condition = std::move(comp);
}
Expand All @@ -98,11 +94,6 @@ void RewriteCorrelatedSubqueriesRecursive::RewriteCorrelatedSubquery(Binder &bin
query.where_clause = std::move(condition);
}
}

// binder.MergeCorrelatedColumns(add_correlation);
// for (auto const &corr : correlated_columns) {
// binder.AddCorrelatedColumn(corr);
// }
}

void RewriteCorrelatedSubqueriesRecursive::VisitExpression(unique_ptr<Expression> &expression) {
Expand Down

0 comments on commit 7b20670

Please sign in to comment.