Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: project rel to and from substrait to include pass through columns #135

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
70 changes: 60 additions & 10 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -492,22 +492,59 @@
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
}

const google::protobuf::RepeatedField<int32_t>& GetOutputMapping(const substrait::Rel &sop) {
const substrait::RelCommon* common = nullptr;
switch (sop.rel_type_case()) {
case substrait::Rel::RelTypeCase::kJoin:
common = &sop.join().common();
break;
case substrait::Rel::RelTypeCase::kProject:
common = &sop.project().common();
break;
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost every relation should be in this list. I'd consider calling anything missing not yet implemented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah..! Thanks for pointing out, I missed this.

This also reminds me that I could avoid a project by using output mapping whenever I only have to change the column order of a given relation.

}
if (!common->has_emit()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be useful to break this into code to get the common table and code to get the emit from the mapping. That way you can use the common structure to update direct and emit if necessary. You might be able to use templates to get the RelCommon, which could reduce the overall amount of code.

static google::protobuf::RepeatedField<int32_t> empty_mapping;
return empty_mapping;
}
return common->emit().output_mapping();
}

shared_ptr<Relation>
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<unique_ptr<ParsedExpression>> expressions;
RootNameIterator iterator(names);

for (auto &sexpr : sop.project().expressions()) {
expressions.push_back(TransformExpr(sexpr, &iterator));
auto input_rel = TransformOp(sop.project().input());

auto mapping = GetOutputMapping(sop);
auto num_input_columns = input_rel->Columns().size();
if (mapping.empty()) {
for (int i = 1; i <= num_input_columns; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first column is numbered at zero. I found https://substrait.io/tutorial/sql_to_substrait/#field-indices to be useful (in addition to the individual relation pages).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index i is used to create the DuckDB column reference, not the Substrait one. For reference, please see here.

expressions.push_back(make_uniq<PositionalReferenceExpression>(i));
}

for (auto &sexpr : sop.project().expressions()) {
expressions.push_back(TransformExpr(sexpr, &iterator));
}
} else {
expressions.resize(mapping.size());
for (size_t i = 0; i < mapping.size(); i++) {
if (mapping[i] < num_input_columns) {
expressions[i] = make_uniq<PositionalReferenceExpression>(mapping[i] + 1);
} else {
expressions[i] = TransformExpr(sop.project().expressions(mapping[i] - num_input_columns), &iterator);
}
}
}

vector<string> mock_aliases;
for (size_t i = 0; i < expressions.size(); i++) {
mock_aliases.push_back("expr_" + to_string(i));
}
return make_shared_ptr<ProjectionRelation>(TransformOp(sop.project().input()), std::move(expressions),
std::move(mock_aliases));
return make_shared_ptr<ProjectionRelation>(input_rel, std::move(expressions), std::move(mock_aliases));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) {
Expand All @@ -515,7 +552,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 555 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 555 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -615,8 +652,8 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
if (!sget.virtual_table().values().empty()) {

Check warning on line 655 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 655 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 656 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 656 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
Expand Down Expand Up @@ -739,15 +776,25 @@
}
auto input = TransformOp(swrite.input());
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
switch (input->type) {
case RelationType::PROJECTION_RELATION: {
auto project = std::move(input.get()->Cast<ProjectionRelation>());
auto filter = std::move(project.child->Cast<FilterRelation>());
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
}
case RelationType::FILTER_RELATION: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported relation type for delete operation");
}
}
default:
throw NotImplementedException("Unsupported write operation " + to_string(swrite.op()));
}
Expand Down Expand Up @@ -822,6 +869,9 @@
if (first_projection_or_table) {
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
int32_t i = 0;
if (column_definitions->size() > column_names.size()) {
throw InvalidInputException("Number of column names less than number of column definitions");
}
for (auto &column : *column_definitions) {
aliases.push_back(column_names[i++]);
auto column_type = column.GetType();
Expand Down
5 changes: 4 additions & 1 deletion src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ class DuckDBToSubstrait {
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static vector<LogicalType>::size_type GetColumnCount(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
static substrait::RelCommon *CreateOutputMapping(vector<int32_t> vector);
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget) const;
void TransformParquetScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget, BindInfo &bind_info,
Expand Down
103 changes: 100 additions & 3 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -847,23 +847,88 @@

if (!dfilter.projection_map.empty()) {
auto projection = new substrait::Rel();
projection->mutable_project()->set_allocated_input(res);
auto sproj = projection->mutable_project();
sproj->set_allocated_input(res);
auto child_column_count = GetColumnCount(*dop.children[0]);
auto t_index = 0;
vector<int32_t> output_mapping;
for (auto col_idx : dfilter.projection_map) {
CreateFieldRef(projection->mutable_project()->add_expressions(), col_idx);
CreateFieldRef(sproj->add_expressions(), col_idx);
output_mapping.push_back(child_column_count + t_index);
}
auto rel_common = CreateOutputMapping(output_mapping);
sproj->set_allocated_common(rel_common);
res = projection;
}
return res;
}

//! Allocates a new RelCommon and fills its emit output mapping with the given
//! column indices, in the order they appear in the input vector.
//! The emit message is created even when the input vector is empty.
//! Returns the raw pointer produced by `new`; the function itself does not free it.
substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vector) {
	auto common = new substrait::RelCommon();
	auto emit_mapping = common->mutable_emit()->mutable_output_mapping();
	for (size_t pos = 0; pos < vector.size(); pos++) {
		emit_mapping->Add(vector[pos]);
	}
	return common;
}

//! Transforms a LogicalProjection into a Substrait project relation.
//! When every projection expression is a bound reference whose index equals its
//! position and the expression count equals the child's column count, no project
//! relation is produced and the transformed child is returned instead.
//! Otherwise an output mapping is collected: bound references map to their input
//! column index, any other expression maps to child_column_count plus a running
//! counter of such expressions.
substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
auto res = new substrait::Rel();
auto &dproj = dop.Cast<LogicalProjection>();

auto child_column_count = GetColumnCount(*dop.children[0]);
auto need_output_mapping = true;
if (child_column_count <= dproj.expressions.size()) {
// Check whether the projection is just a pass-through of the input columns with no reordering.
// exp_col_idx counts how many leading expressions are bound references to column 0, 1, 2, ... in order.
auto exp_col_idx = 0;
auto is_passthrough = true;
for (auto &dexpr : dproj.expressions) {
if (dexpr->type != ExpressionType::BOUND_REF) {
is_passthrough = false;
break;
}
auto &dref = dexpr.get()->Cast<BoundReferenceExpression>();
if (dref.index != exp_col_idx) {
is_passthrough = false;
break;
}
exp_col_idx++;
}
if (is_passthrough && child_column_count == exp_col_idx) {
// Skip the projection and return the transformed child directly.
// NOTE(review): `res` was allocated with `new` above and is not released on this path — confirm whether this leaks.
return TransformOp(*dop.children[0]);
}
if (child_column_count == exp_col_idx) {
// The leading expressions reference every input column in order, so no output mapping is attached.
need_output_mapping = false;
}
}

auto sproj = res->mutable_project();
sproj->set_allocated_input(TransformOp(*dop.children[0]));

// t_index counts the expressions that are not plain bound references.
auto t_index = 0;
vector<int32_t> output_mapping;
for (auto &dexpr : dproj.expressions) {
// NOTE(review): this call adds a project expression for every entry, and the default branch below
// adds one again for non-reference expressions. Looks like a leftover from the previous version of
// this loop — confirm, since it would not line up with the indices pushed into output_mapping.
TransformExpr(*dexpr, *sproj->add_expressions());
switch (dexpr->type) {
case ExpressionType::BOUND_REF: {
// A bound reference is emitted by pointing the mapping at the referenced input column.
auto &dref = dexpr.get()->Cast<BoundReferenceExpression>();
output_mapping.push_back(dref.index);
break;
}
default:
// Any other expression is added to the project relation and mapped to a slot after the input columns.
TransformExpr(*dexpr.get(), *sproj->add_expressions());
output_mapping.push_back(child_column_count + t_index);
t_index++;
}
}
if (need_output_mapping) {
if (sproj->expressions_size() == 0) {
// At least one expression should be present; add a reference to column zero as a dummy expression.
CreateFieldRef(sproj->add_expressions(), 0);
}
auto rel_common = CreateOutputMapping(output_mapping);
sproj->set_allocated_common(rel_common);
}
return res;
}
Expand Down Expand Up @@ -987,17 +1052,26 @@
djoin.right_projection_map.push_back(i);
}
}
// TODO this projection seems redundant but from_substrait does not work without it
auto proj_rel = new substrait::Rel();
auto projection = proj_rel->mutable_project();
auto child_column_count = GetColumnCount(*dop.children[0]);
for (auto left_idx : djoin.left_projection_map) {
CreateFieldRef(projection->add_expressions(), left_idx);
}
if (djoin.join_type != JoinType::SEMI) {
child_column_count += GetColumnCount(*dop.children[1]);
for (auto right_idx : djoin.right_projection_map) {
CreateFieldRef(projection->add_expressions(), right_idx + left_col_count);
}
}

vector<int32_t> output_mapping;
for (idx_t i = 0; i < projection->expressions_size(); i++) {
output_mapping.push_back(child_column_count + i);
}
auto rel_common = CreateOutputMapping(output_mapping);
projection->set_allocated_common(rel_common);
projection->set_allocated_input(res);
return proj_rel;
}
Expand All @@ -1014,7 +1088,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1091 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1091 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1282,7 +1356,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1359 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1359 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down Expand Up @@ -1324,6 +1398,25 @@
}
projection->set_allocated_select(select);
sget->set_allocated_projection(projection);
} else if (!dget.GetColumnIds().empty()) {
auto &column_ids = dget.GetColumnIds();
vector<int> column_indices;
for (auto &column_id : column_ids) {
if (!column_id.IsRowIdColumn()) {
column_indices.push_back(column_id.GetPrimaryIndex());
}
}
if (!column_indices.empty() && column_indices.size() < dget.returned_types.size()) {
auto projection = new substrait::Expression_MaskExpression();
projection->set_maintain_singular_struct(true);
auto select = new substrait::Expression_MaskExpression_StructSelect();
for (auto col_idx : column_indices) {
auto struct_item = select->add_struct_items();
struct_item->set_field(static_cast<int32_t>(col_idx));
}
projection->set_allocated_select(select);
sget->set_allocated_projection(projection);
}
}

// Add Table Schema
Expand Down Expand Up @@ -1540,6 +1633,10 @@
return rel;
}

//! Returns the number of entries in the operator's `types` vector.
vector<LogicalType>::size_type DuckDBToSubstrait::GetColumnCount(LogicalOperator &dop) {
	const auto &op_types = dop.types;
	return op_types.size();
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down
2 changes: 1 addition & 1 deletion test/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include_directories(../../duckdb/src/include)
include_directories(../../duckdb/test/include)
include_directories(../../duckdb/third_party/catch)

set(ALL_SOURCES test_substrait_c_api.cpp)
set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp test_projection.cpp)


add_library_unity(test_substrait OBJECT ${ALL_SOURCES})
Expand Down
Loading
Loading