Skip to content

Commit

Permalink
Unify handling of materialized CTEs using CTE map
Browse files Browse the repository at this point in the history
  • Loading branch information
kryonix committed Mar 18, 2024
1 parent 2d99bed commit 3f3f7b2
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 37 deletions.
6 changes: 2 additions & 4 deletions src/include/duckdb/parser/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,8 @@ class Transformer {
OnCreateConflict TransformOnConflict(duckdb_libpgquery::PGOnCreateConflict conflict);
string TransformAlias(duckdb_libpgquery::PGAlias *root, vector<string> &column_name_alias);
vector<string> TransformStringList(duckdb_libpgquery::PGList *list);
void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map,
vector<unique_ptr<CTENode>> &materialized_ctes);
static unique_ptr<QueryNode> TransformMaterializedCTE(unique_ptr<QueryNode> root,
vector<unique_ptr<CTENode>> &materialized_ctes);
void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map);
static unique_ptr<QueryNode> TransformMaterializedCTE(unique_ptr<QueryNode> root);
unique_ptr<SelectStatement> TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node,
CommonTableExpressionInfo &info);

Expand Down
5 changes: 1 addition & 4 deletions src/parser/transform/helpers/transform_cte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) {
}
}

void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map,
vector<unique_ptr<CTENode>> &materialized_ctes) {
void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map) {
stored_cte_map.push_back(&cte_map);

// TODO: might need to update in case of future lawsuit
Expand Down Expand Up @@ -92,8 +91,6 @@ void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause,
materialize->query = info->query->node->Copy();
materialize->ctename = cte_name;
materialize->aliases = info->aliases;
materialized_ctes.push_back(std::move(materialize));

info->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS;
}

Expand Down
4 changes: 1 addition & 3 deletions src/parser/transform/statement/transform_delete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ namespace duckdb {

unique_ptr<DeleteStatement> Transformer::TransformDelete(duckdb_libpgquery::PGDeleteStmt &stmt) {
auto result = make_uniq<DeleteStatement>();
vector<unique_ptr<CTENode>> materialized_ctes;
if (stmt.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map);
}

result->condition = TransformExpression(stmt.whereClause);
Expand Down
7 changes: 2 additions & 5 deletions src/parser/transform/statement/transform_insert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ unique_ptr<TableRef> Transformer::TransformValuesList(duckdb_libpgquery::PGList

unique_ptr<InsertStatement> Transformer::TransformInsert(duckdb_libpgquery::PGInsertStmt &stmt) {
auto result = make_uniq<InsertStatement>();
vector<unique_ptr<CTENode>> materialized_ctes;
if (stmt.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map);
}

// first check if there are any columns specified
Expand All @@ -44,8 +42,7 @@ unique_ptr<InsertStatement> Transformer::TransformInsert(duckdb_libpgquery::PGIn
}
if (stmt.selectStmt) {
result->select_statement = TransformSelect(stmt.selectStmt, false);
result->select_statement->node =
TransformMaterializedCTE(std::move(result->select_statement->node), materialized_ctes);
result->select_statement->node = TransformMaterializedCTE(std::move(result->select_statement->node));
} else {
result->default_values = true;
}
Expand Down
8 changes: 2 additions & 6 deletions src/parser/transform/statement/transform_pivot_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,9 @@ unique_ptr<QueryNode> Transformer::TransformPivotStatement(duckdb_libpgquery::PG
bool has_parameters = next_param_count > current_param_count;

auto select_node = make_uniq<SelectNode>();
vector<unique_ptr<CTENode>> materialized_ctes;
// handle the CTEs
if (select.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(select.withClause), select_node->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(select.withClause), select_node->cte_map);
}
if (!pivot->columns) {
// no pivot columns - not actually a pivot
Expand Down Expand Up @@ -215,9 +213,7 @@ unique_ptr<QueryNode> Transformer::TransformPivotStatement(duckdb_libpgquery::PG
// transform order by/limit modifiers
TransformModifiers(select, *select_node);

auto node = Transformer::TransformMaterializedCTE(std::move(select_node), materialized_ctes);

return node;
return select_node;
}

} // namespace duckdb
7 changes: 5 additions & 2 deletions src/parser/transform/statement/transform_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
namespace duckdb {

unique_ptr<QueryNode> Transformer::TransformSelectNode(duckdb_libpgquery::PGSelectStmt &select) {
unique_ptr<QueryNode> stmt = nullptr;
if (select.pivot) {
return TransformPivotStatement(select);
stmt = TransformPivotStatement(select);
} else {
return TransformSelectInternal(select);
stmt = TransformSelectInternal(select);
}

return TransformMaterializedCTE(std::move(stmt));
}

unique_ptr<SelectStatement> Transformer::TransformSelect(duckdb_libpgquery::PGSelectStmt &select, bool is_select) {
Expand Down
10 changes: 2 additions & 8 deletions src/parser/transform/statement/transform_select_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ unique_ptr<QueryNode> Transformer::TransformSelectInternal(duckdb_libpgquery::PG
auto stack_checker = StackCheck();

unique_ptr<QueryNode> node;
vector<unique_ptr<CTENode>> materialized_ctes;

switch (stmt.op) {
case duckdb_libpgquery::PG_SETOP_NONE: {
node = make_uniq<SelectNode>();
auto &result = node->Cast<SelectNode>();
if (stmt.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), node->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), node->cte_map);
}
if (stmt.windowClause) {
for (auto window_ele = stmt.windowClause->head; window_ele != nullptr; window_ele = window_ele->next) {
Expand Down Expand Up @@ -117,8 +115,7 @@ unique_ptr<QueryNode> Transformer::TransformSelectInternal(duckdb_libpgquery::PG
node = make_uniq<SetOperationNode>();
auto &result = node->Cast<SetOperationNode>();
if (stmt.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), node->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), node->cte_map);
}
result.left = TransformSelectNode(*stmt.larg);
result.right = TransformSelectNode(*stmt.rarg);
Expand Down Expand Up @@ -154,9 +151,6 @@ unique_ptr<QueryNode> Transformer::TransformSelectInternal(duckdb_libpgquery::PG

TransformModifiers(stmt, *node);

// Handle materialized CTEs
node = Transformer::TransformMaterializedCTE(std::move(node), materialized_ctes);

return node;
}

Expand Down
4 changes: 1 addition & 3 deletions src/parser/transform/statement/transform_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ unique_ptr<UpdateSetInfo> Transformer::TransformUpdateSetInfo(duckdb_libpgquery:

unique_ptr<UpdateStatement> Transformer::TransformUpdate(duckdb_libpgquery::PGUpdateStmt &stmt) {
auto result = make_uniq<UpdateStatement>();
vector<unique_ptr<CTENode>> materialized_ctes;
if (stmt.withClause) {
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map,
materialized_ctes);
TransformCTE(*PGPointerCast<duckdb_libpgquery::PGWithClause>(stmt.withClause), result->cte_map);
}

result->table = TransformRangeVar(*stmt.relation);
Expand Down
16 changes: 14 additions & 2 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,20 @@ unique_ptr<SQLStatement> Transformer::TransformStatementInternal(duckdb_libpgque
}
}

unique_ptr<QueryNode> Transformer::TransformMaterializedCTE(unique_ptr<QueryNode> root,
vector<unique_ptr<CTENode>> &materialized_ctes) {
unique_ptr<QueryNode> Transformer::TransformMaterializedCTE(unique_ptr<QueryNode> root) {
// Extract materialized CTEs from cte_map
vector<unique_ptr<CTENode>> materialized_ctes;
for (auto &cte : root->cte_map.map) {
auto &cte_entry = cte.second;
if (cte_entry->materialized == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) {
auto mat_cte = make_uniq<CTENode>();
mat_cte->ctename = cte.first;
mat_cte->query = cte_entry->query->node->Copy();
mat_cte->aliases = cte_entry->aliases;
materialized_ctes.push_back(std::move(mat_cte));
}
}

while (!materialized_ctes.empty()) {
unique_ptr<CTENode> node_result;
node_result = std::move(materialized_ctes.back());
Expand Down

0 comments on commit 3f3f7b2

Please sign in to comment.