Skip to content

Commit

Permalink
Fix 1. projection with filter and 2. Pushdown projection in ReadRel
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Jan 17, 2025
1 parent 98a85ea commit 7ed491d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
22 changes: 16 additions & 6 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,15 +776,25 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
}
auto input = TransformOp(swrite.input());
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
switch (input->type) {
case RelationType::PROJECTION_RELATION: {
auto project = std::move(input.get()->Cast<ProjectionRelation>());
auto filter = std::move(project.child->Cast<FilterRelation>());
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
}
case RelationType::FILTER_RELATION: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported relation type for delete operation");
}
}
default:
throw NotImplementedException("Unsupported write operation " + to_string(swrite.op()));
}
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static vector<LogicalType>::__alloc_traits::size_type GetColumnCount(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
static substrait::RelCommon *CreateOutputMapping(vector<int32_t> vector);
//! Methods to transform different LogicalGe:75
Expand Down
40 changes: 35 additions & 5 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,10 +847,17 @@ substrait::Rel *DuckDBToSubstrait::TransformFilter(LogicalOperator &dop) {

if (!dfilter.projection_map.empty()) {
auto projection = new substrait::Rel();
projection->mutable_project()->set_allocated_input(res);
auto sproj = projection->mutable_project();
sproj->set_allocated_input(res);
auto child_column_count = GetColumnCount(*dop.children[0]);
auto t_index = 0;
vector<int32_t> output_mapping;
for (auto col_idx : dfilter.projection_map) {
CreateFieldRef(projection->mutable_project()->add_expressions(), col_idx);
CreateFieldRef(sproj->add_expressions(), col_idx);
output_mapping.push_back(child_column_count + t_index);
}
auto rel_common = CreateOutputMapping(output_mapping);
sproj->set_allocated_common(rel_common);
res = projection;
}
return res;
Expand All @@ -869,7 +876,7 @@ substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
auto res = new substrait::Rel();
auto &dproj = dop.Cast<LogicalProjection>();

auto child_column_count = dop.children[0]->types.size();
auto child_column_count = GetColumnCount(*dop.children[0]);
auto need_output_mapping = true;
if (child_column_count <= dproj.expressions.size()) {
// check if the projection is just pass through of input columns with no reordering
Expand Down Expand Up @@ -1048,12 +1055,12 @@ substrait::Rel *DuckDBToSubstrait::TransformComparisonJoin(LogicalOperator &dop)
// TODO this projection seems redundant but from_substrait does not work without it
auto proj_rel = new substrait::Rel();
auto projection = proj_rel->mutable_project();
auto child_column_count = dop.children[0]->types.size();
auto child_column_count = GetColumnCount(*dop.children[0]);
for (auto left_idx : djoin.left_projection_map) {
CreateFieldRef(projection->add_expressions(), left_idx);
}
if (djoin.join_type != JoinType::SEMI) {
child_column_count += dop.children[1]->types.size();
child_column_count += GetColumnCount(*dop.children[1]);
for (auto right_idx : djoin.right_projection_map) {
CreateFieldRef(projection->add_expressions(), right_idx + left_col_count);
}
Expand Down Expand Up @@ -1391,6 +1398,25 @@ substrait::Rel *DuckDBToSubstrait::TransformGet(LogicalOperator &dop) {
}
projection->set_allocated_select(select);
sget->set_allocated_projection(projection);
} else if (!dget.GetColumnIds().empty()) {
auto &column_ids = dget.GetColumnIds();
vector<int> column_indices;
for (auto &column_id : column_ids) {
if (!column_id.IsRowIdColumn()) {
column_indices.push_back(column_id.GetPrimaryIndex());
}
}
if (!column_indices.empty() && column_indices.size() < dget.returned_types.size()) {
auto projection = new substrait::Expression_MaskExpression();
projection->set_maintain_singular_struct(true);
auto select = new substrait::Expression_MaskExpression_StructSelect();
for (auto col_idx : column_indices) {
auto struct_item = select->add_struct_items();
struct_item->set_field(static_cast<int32_t>(col_idx));
}
projection->set_allocated_select(select);
sget->set_allocated_projection(projection);
}
}

// Add Table Schema
Expand Down Expand Up @@ -1607,6 +1633,10 @@ substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
return rel;
}

vector<LogicalType>::__alloc_traits::size_type DuckDBToSubstrait::GetColumnCount(LogicalOperator &dop) {
return dop.types.size();
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down
14 changes: 13 additions & 1 deletion test/c/test_projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,16 @@ TEST_CASE("Test Project with duplicate columns", "[substrait-api]") {
auto query_json = R"({"relations":[{"root":{"input":{"project":{"input":{"fetch":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"count":"5"}},"expressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}},"names":["i", "integers"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto res1 = con.FromSubstraitJSON(query_json);
REQUIRE(CHECK_COLUMN(res1, 0, {1, 2, 3, Value()}));
}
}

TEST_CASE("Test Project simple join on tables with multiple columns", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
REQUIRE_NO_FAIL(con.Query("CALL dbgen(sf=0.000001)"));

auto query_text_2 = "SELECT extract(year FROM o_orderdate), l_extendedprice * (1 - l_discount) AS amount FROM lineitem, orders WHERE o_orderkey = l_orderkey";
auto json2 = con.GetSubstraitJSON(query_text_2);
auto expected_json = R"cust_raw({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"},{"extensionUriAnchor":2,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_datetime.yaml"},{"extensionUriAnchor":3,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"equal:i64_i64"}},{"extensionFunction":{"extensionUriReference":2,"functionAnchor":2,"name":"extract:date"}},{"extensionFunction":{"extensionUriReference":3,"functionAnchor":3,"name":"subtract:decimal_decimal"}},{"extensionFunction":{"extensionUriReference":3,"functionAnchor":4,"name":"multiply:decimal_decimal"}}],"relations":[{"root":{"input":{"project":{"common":{"emit":{"outputMapping":[3,4]}},"input":{"project":{"common":{"emit":{"outputMapping":[26,27,28]}},"input":{"join":{"left":{"project":{"common":{"emit":{"outputMapping":[7,6,8,5,9,4,10,3,11,2,12,1,13,0,14]}},"input":{"project":{"common":{"emit":{"outputMapping":[3,2,4,1,5,0,6]}},"input":{"read":{"baseSchema":{"names":["l_orderkey","l_partkey","l_suppkey","l_linenumber","l_quantity","l_extendedprice","l_discount","l_tax","l_returnflag","l_linestatus","l_shipdate","l_commitdate","l_receiptdate","l_shipinstruct","l_shipmode","l_comment"],"struct":{"types":[{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"decimal":{"scale":2,"precision":15,"nullability":"NULLABILITY_REQUIRED"}},{"decimal":{"scale":2,"precision":15,"nullability":"NULLABILITY_REQUIRED"}},{"decimal":{"scale":2,"precision":15,"nullability":"NULLABILITY_REQUIRED"}},{"decimal":{"scale":2,"precision":15,"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"date":{"nullability":"NULLABILITY_REQUIRED"}},{"date":{"nullability":"NULLABILITY_REQUIRED"}},{"date":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{},{"field":5},{"field":6}]},"maintainSingularStruct":true},"namedTable":{"names":["lineitem"]}}},"expressions":[{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}}]}},"expressions":[{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}}]}},"right":{"project":{"common":{"emit":{"outputMapping":[5,4,6,3,7,2,8,1,9,0,10]}},"input":{"project":{"common":{"emit":{"outputMapping":[2,1,3,0,4]}},"input":{"read":{"baseSchema":{"names":["o_orderkey","o_custkey","o_orderstatus","o_totalprice","o_orderdate","o_orderpriority","o_clerk","o_shippriority","o_comment"],"struct":{"types":[{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"i64":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"decimal":{"scale":2,"precision":15,"nullability":"NULLABILITY_REQUIRED"}},{"date":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}},{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_REQUIRED"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{},{"field":4}]},"maintainSingularStruct":true},"namedTable":{"names":["orders"]}}},"expressions":[{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}}]}},"expressions":[{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}},{"literal":{"null":{}}}]}},"expression":{"scalarFunction":{"functionReference":1,"outputType":{"bool":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":3}},"rootReference":{}}}},{"value":{"selection":{"directReference":{"structField":{"field":18}},"rootReference":{}}}}]}},"type":"JOIN_TYPE_INNER"}},"expressions":[{"selection":{"directReference":{"structField":{"field":7}},"rootReference":{}}},{"selection":{"directReference":{"structField":{"field":11}},"rootReference":{}}},{"selection":{"directReference":{"structField":{"field":22}},"rootReference":{}}}]}},"expressions":[{"scalarFunction":{"functionReference":2,"outputType":{"i64":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"enum":"year"},{"value":{"selection":{"directReference":{"structField":{"field":2}},"rootReference":{}}}}]}},{"scalarFunction":{"functionReference":4,"outputType":{"decimal":{"scale":4,"precision":18,"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}},{"value":{"scalarFunction":{"functionReference":3,"outputType":{"decimal":{"scale":2,"precision":16,"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"literal":{"decimal":{"value":"ZAAAAAAAAAAAAAAAAAAAAA==","precision":16,"scale":2}}}},{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}}]}}}]}}]}},"names":["\"year\"(o_orderdate)","amount"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})cust_raw";
REQUIRE(json2 == expected_json);
}

0 comments on commit 7ed491d

Please sign in to comment.