Skip to content

Commit

Permalink
project rel to and from substrait to include pass through columns
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Jan 8, 2025
1 parent 80fb812 commit b11ab6f
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 7 deletions.
45 changes: 41 additions & 4 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,22 +492,59 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
}

//! Returns the emit output mapping stored in the RelCommon of a join or project relation.
//! When the relation carries no emit, a reference to a shared empty mapping is returned.
//! Any other relation type raises an InternalException.
const google::protobuf::RepeatedField<int32_t>& GetOutputMapping(const substrait::Rel &sop) {
	const auto rel_type = sop.rel_type_case();
	const substrait::RelCommon* rel_common = nullptr;
	if (rel_type == substrait::Rel::RelTypeCase::kJoin) {
		rel_common = &sop.join().common();
	} else if (rel_type == substrait::Rel::RelTypeCase::kProject) {
		rel_common = &sop.project().common();
	} else {
		throw InternalException("Unsupported relation type " + to_string(rel_type));
	}

	if (rel_common->has_emit()) {
		return rel_common->emit().output_mapping();
	}
	// A function-local static keeps the returned reference valid after this call returns.
	static google::protobuf::RepeatedField<int32_t> empty_mapping;
	return empty_mapping;
}

//! Converts a Substrait ProjectRel into a DuckDB ProjectionRelation.
//!
//! Without an output mapping the projection yields every input column (as positional
//! references) followed by every project expression. With an output mapping, each entry
//! selects either an input column (index < number of input columns) or a project
//! expression (index - number of input columns), in the order given by the mapping.
//!
//! \param sop   relation whose rel type is ProjectRel
//! \param names optional root names handed to the RootNameIterator used by TransformExpr
//! \throws InvalidInputException if an output mapping entry is negative or addresses
//!         neither an input column nor a project expression
shared_ptr<Relation>
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
                                      const google::protobuf::RepeatedPtrField<std::string> *names) {
	vector<unique_ptr<ParsedExpression>> expressions;
	RootNameIterator iterator(names);

	const auto &sproj = sop.project();
	// Transform the input once; it is needed both for its column count and as the child relation.
	auto input_rel = TransformOp(sproj.input());

	// Bind by reference: GetOutputMapping returns a reference, copying the field is unnecessary.
	const auto &mapping = GetOutputMapping(sop);
	const size_t num_input_columns = input_rel->Columns().size();
	const size_t num_expressions = static_cast<size_t>(sproj.expressions_size());

	if (mapping.empty()) {
		// Positional references are 1-based.
		for (size_t i = 1; i <= num_input_columns; i++) {
			expressions.push_back(make_uniq<PositionalReferenceExpression>(i));
		}
		for (auto &sexpr : sproj.expressions()) {
			expressions.push_back(TransformExpr(sexpr, &iterator));
		}
	} else {
		expressions.reserve(mapping.size());
		for (const auto mapped_idx : mapping) {
			// The mapping comes from the incoming plan: validate before indexing into the
			// expression list, and before any signed -> unsigned conversion.
			if (mapped_idx < 0 || static_cast<size_t>(mapped_idx) >= num_input_columns + num_expressions) {
				throw InvalidInputException("Project output mapping index " + to_string(mapped_idx) +
				                            " is out of range");
			}
			const auto column_idx = static_cast<size_t>(mapped_idx);
			if (column_idx < num_input_columns) {
				expressions.push_back(make_uniq<PositionalReferenceExpression>(column_idx + 1));
			} else {
				const auto expr_idx = static_cast<int>(column_idx - num_input_columns);
				expressions.push_back(TransformExpr(sproj.expressions(expr_idx), &iterator));
			}
		}
	}

	vector<string> mock_aliases;
	mock_aliases.reserve(expressions.size());
	for (size_t i = 0; i < expressions.size(); i++) {
		mock_aliases.push_back("expr_" + to_string(i));
	}
	return make_shared_ptr<ProjectionRelation>(std::move(input_rel), std::move(expressions),
	                                           std::move(mock_aliases));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) {
Expand Down
4 changes: 3 additions & 1 deletion src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ class DuckDBToSubstrait {
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
static substrait::RelCommon *CreateOutputMapping(vector<int32_t> vector);
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget) const;
void TransformParquetScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget, BindInfo &bind_info,
Expand Down
66 changes: 65 additions & 1 deletion src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,14 +856,71 @@ substrait::Rel *DuckDBToSubstrait::TransformFilter(LogicalOperator &dop) {
return res;
}

//! Allocates a RelCommon whose emit output mapping holds the given column indices, in order.
//! The caller receives the raw pointer returned by `new`.
substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vector) {
	auto common = new substrait::RelCommon();
	// mutable_emit() is called even for an empty input, matching the unconditional call chain.
	auto emit_indices = common->mutable_emit()->mutable_output_mapping();
	for (size_t pos = 0; pos < vector.size(); pos++) {
		emit_indices->Add(vector[pos]);
	}
	return common;
}

//! Converts a DuckDB LogicalProjection into a Substrait ProjectRel.
//!
//! Plain column references (BOUND_REF) are emitted as pass-through input columns via the
//! output mapping; every other expression is added to the ProjectRel expression list and
//! addressed as (child column count + its position in that list).
//!
//! - If the projection is exactly the child's columns in order, no ProjectRel is produced
//!   and the transformed child is returned directly.
//! - If the output is exactly "all child columns in order, followed by all transformed
//!   expressions in order", the emit is omitted because that is what a ProjectRel without
//!   an output mapping already yields.
//! - Otherwise an emit with the computed output mapping is attached.
substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
	auto &dproj = dop.Cast<LogicalProjection>();
	const idx_t child_column_count = dop.children[0]->types.size();

	// First pass: classify expressions without touching any protobuf state, so the
	// projection can still be skipped without side effects.
	vector<int32_t> output_mapping;
	output_mapping.reserve(dproj.expressions.size());
	idx_t transformed_count = 0;
	for (auto &dexpr : dproj.expressions) {
		if (dexpr->type == ExpressionType::BOUND_REF) {
			auto &dref = dexpr->Cast<BoundReferenceExpression>();
			output_mapping.push_back(static_cast<int32_t>(dref.index));
		} else {
			output_mapping.push_back(static_cast<int32_t>(child_column_count + transformed_count));
			transformed_count++;
		}
	}

	// The mapping is redundant only when it is the identity over
	// [child columns..., transformed expressions...]. Checking the whole mapping (not just a
	// leading prefix) keeps repeated or reordered column references from being dropped.
	bool is_identity = output_mapping.size() == child_column_count + transformed_count;
	for (idx_t i = 0; is_identity && i < output_mapping.size(); i++) {
		is_identity = output_mapping[i] == static_cast<int32_t>(i);
	}

	if (is_identity && transformed_count == 0) {
		// Pure pass-through of the child: skip the projection entirely.
		return TransformOp(*dop.children[0]);
	}

	// Allocate only after the early return above so the skip path cannot leak the Rel.
	auto res = new substrait::Rel();
	auto sproj = res->mutable_project();
	sproj->set_allocated_input(TransformOp(*dop.children[0]));

	// Second pass: only non-reference expressions become ProjectRel expressions.
	for (auto &dexpr : dproj.expressions) {
		if (dexpr->type != ExpressionType::BOUND_REF) {
			TransformExpr(*dexpr, *sproj->add_expressions());
		}
	}
	if (!is_identity) {
		sproj->set_allocated_common(CreateOutputMapping(output_mapping));
	}
	return res;
}
Expand Down Expand Up @@ -998,6 +1055,13 @@ substrait::Rel *DuckDBToSubstrait::TransformComparisonJoin(LogicalOperator &dop)
}
}

auto child_column_count = dop.children[0]->types.size() + dop.children[1]->types.size();
vector<int32_t> output_mapping;
for (idx_t i = 0; i < projection->expressions_size(); i++) {
output_mapping.push_back(child_column_count + i);
}
auto rel_common = CreateOutputMapping(output_mapping);
projection->set_allocated_common(rel_common);
projection->set_allocated_input(res);
return proj_rel;
}
Expand Down
2 changes: 1 addition & 1 deletion test/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include_directories(../../duckdb/src/include)
include_directories(../../duckdb/test/include)
include_directories(../../duckdb/third_party/catch)

set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp)
set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp test_projection.cpp)


add_library_unity(test_substrait OBJECT ${ALL_SOURCES})
Expand Down
136 changes: 136 additions & 0 deletions test/c/test_projection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "test_substrait_c_utils.hpp"

#include <chrono>
#include <thread>
#include <iostream>

using namespace duckdb;
using namespace std;

// A single pass-through column: the expected plan is a bare read with no project relation.
TEST_CASE("Test C Project input columns with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
	REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
	// The employees table is created here although the query below only reads `integers`.
	CreateEmployeeTable(con);

	const auto plan_json = con.GetSubstraitJSON("SELECT i FROM integers");
	const auto expected_plan = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"names":["i"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	REQUIRE(plan_json == expected_plan);

	// Round trip: executing the generated plan returns the inserted rows.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {10, 20, 30}));
}

// One pass-through column followed by one computed column: the expected plan has a
// project relation with a single expression and no emit/output mapping.
TEST_CASE("Test C Project 1 input column 1 transformation with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
	REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
	// The employees table is created here although the query below only reads `integers`.
	CreateEmployeeTable(con);

	const auto plan_json = con.GetSubstraitJSON("SELECT i, i *i as isquare FROM integers");
	const auto expected_plan = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:i32_i32"}}],"relations":[{"root":{"input":{"project":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"i32":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}},{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}}]}}]}},"names":["i","isquare"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	REQUIRE(plan_json == expected_plan);

	// Round trip: both the original and the squared column come back.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {10, 20, 30}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {100, 400, 900}));
}

// SELECT * over a table: the expected plan contains no ProjectRel node, only the read.
TEST_CASE("Test C Project all columns with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);

	const auto expected_plan = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{},{"field":1},{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["employee_id","name","department_id","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	const auto plan_json = con.GetSubstraitJSON("SELECT * FROM employees");
	REQUIRE(plan_json == expected_plan);

	// Round trip: all four columns are returned.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {1, 2, 3, 4, 5}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
	REQUIRE(CHECK_COLUMN(roundtrip, 2, {1, 2, 1, 3, 2}));
	REQUIRE(CHECK_COLUMN(roundtrip, 3, {120000, 80000, 50000, 95000, 60000}));
}

// Two pass-through columns: the expected plan contains no ProjectRel node; the column
// selection appears in the read relation's projection instead.
TEST_CASE("Test C Project two passthrough columns with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);

	const auto expected_plan = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	const auto plan_json = con.GetSubstraitJSON("SELECT name, salary FROM employees");
	REQUIRE(plan_json == expected_plan);

	// Round trip: only the two selected columns are returned.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {120000, 80000, 50000, 95000, 60000}));
}

// Two pass-through columns plus a WHERE clause: the expected plan contains no ProjectRel
// node; both the filter and the column selection appear inside the read relation.
TEST_CASE("Test C Project two passthrough columns with filter", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);

	const auto expected_plan = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"equal:i32_i32"}}],"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"filter":{"scalarFunction":{"functionReference":1,"outputType":{"bool":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":2}},"rootReference":{}}}},{"value":{"literal":{"i32":1}}}]}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	const auto plan_json = con.GetSubstraitJSON("SELECT name, salary FROM employees where department_id = 1");
	REQUIRE(plan_json == expected_plan);

	// Round trip: only the rows with department_id = 1 are returned.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {"John Doe", "Alice Johnson"}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {120000, 50000}));
}

// One pass-through column and one computed column where the computed column's source is
// not itself part of the output: the expected plan carries an emit with outputMapping [0,2].
TEST_CASE("Test C Project 1 passthrough column, 1 transformation with column elimination", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);

	const auto expected_plan = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:decimal_decimal"}}],"relations":[{"root":{"input":{"project":{"common":{"emit":{"outputMapping":[0,2]}},"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"decimal":{"scale":3,"precision":12,"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}},{"value":{"literal":{"decimal":{"value":"DAAAAAAAAAAAAAAAAAAAAA==","precision":12,"scale":1}}}}]}}]}},"names":["name","new_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	const auto plan_json = con.GetSubstraitJSON("SELECT name, salary * 1.2 as new_salary FROM employees");
	REQUIRE(plan_json == expected_plan);

	// Round trip: names plus the scaled salaries.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {144000, 96000, 60000, 114000, 72000}));
}

// Grouping column plus an aggregate: the expected plan is an aggregate relation directly
// over the read, with no project relation.
TEST_CASE("Test C Project 1 passthrough column and 1 aggregate transformation", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);

	const auto expected_plan = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"avg:decimal"}}],"relations":[{"root":{"input":{"aggregate":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"groupings":[{"groupingExpressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}],"measures":[{"measure":{"functionReference":1,"outputType":{"fp64":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}}]}}]}},"names":["department_id","avg_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
	const auto plan_json = con.GetSubstraitJSON("SELECT department_id, AVG(salary) AS avg_salary FROM employees GROUP BY department_id");
	REQUIRE(plan_json == expected_plan);

	// NOTE(review): the query has no ORDER BY, so these checks rely on the group output order - confirm it is stable.
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {1, 2, 3}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {85000, 70000, 95000}));
}

// Projection on top of a join: only the executed result is checked here, not the plan JSON.
TEST_CASE("Test C Project on Join with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	CreateEmployeeTable(con);
	CreateDepartmentsTable(con);

	// NOTE(review): the query has no ORDER BY, so the checks below rely on the join output order - confirm it is stable.
	auto joined = ExecuteViaSubstraitJSON(con, "SELECT e.employee_id, e.name, d.department_name "
	                                           "FROM employees e "
	                                           "JOIN departments d "
	                                           "ON e.department_id = d.department_id");

	REQUIRE(CHECK_COLUMN(joined, 0, {1, 2, 3, 4, 5}));
	REQUIRE(CHECK_COLUMN(joined, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
	REQUIRE(CHECK_COLUMN(joined, 2, {"HR", "Engineering", "HR", "Finance", "Engineering"}));
}

0 comments on commit b11ab6f

Please sign in to comment.