Skip to content

Commit

Permalink
Merge pull request #107 from pdet/substrait_regressions
Browse files Browse the repository at this point in the history
Make projection finding function optional
  • Loading branch information
pdet authored Sep 10, 2024
2 parents 800be49 + 62316a6 commit f35aa93
Show file tree
Hide file tree
Showing 4 changed files with 2,404 additions and 36 deletions.
58 changes: 26 additions & 32 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,31 +331,30 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
auto &nested_expression = sexpr.nested();
if (nested_expression.has_struct_()) {
auto &struct_expression = nested_expression.struct_();
vector<unique_ptr<ParsedExpression>> children;
for (auto& child: struct_expression.fields()) {
vector<unique_ptr<ParsedExpression>> children;
for (auto &child : struct_expression.fields()) {
children.emplace_back(TransformExpr(child));
}
return make_uniq<FunctionExpression>("row", std::move(children));
} else if (nested_expression.has_list()) {
auto &list_expression = nested_expression.list();
vector<unique_ptr<ParsedExpression>> children;
for (auto& child: list_expression.values()) {
vector<unique_ptr<ParsedExpression>> children;
for (auto &child : list_expression.values()) {
children.emplace_back(TransformExpr(child));
}
return make_uniq<FunctionExpression>("list_value", std::move(children));

} else if (nested_expression.has_map()) {
auto &map_expression = nested_expression.map();
vector<unique_ptr<ParsedExpression>> children;
vector<unique_ptr<ParsedExpression>> children;
auto key_value = map_expression.key_values();
children.emplace_back(TransformExpr(key_value[0].key()));
children.emplace_back(TransformExpr(key_value[0].value()));
return make_uniq<FunctionExpression>("map", std::move(children));

} else{
} else {
throw NotImplementedException("Substrait nested expression is not yet implemented.");
}

}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
Expand Down Expand Up @@ -663,25 +662,18 @@ int32_t SkipColumnNames(const LogicalType &type) {
return columns_to_skip;
}

Relation *GetProjectionOrTableRelation(Relation &relation, string &error) {
error += RelationTypeToString(relation.type);
Relation *GetProjection(Relation &relation) {
switch (relation.type) {
case RelationType::TABLE_RELATION:
case RelationType::PROJECTION_RELATION:
error += " -> ";
return &relation;
case RelationType::LIMIT_RELATION:
error += " -> ";
return GetProjectionOrTableRelation(*relation.Cast<LimitRelation>().child, error);
return GetProjection(*relation.Cast<LimitRelation>().child);
case RelationType::ORDER_RELATION:
error += " -> ";
return GetProjectionOrTableRelation(*relation.Cast<OrderRelation>().child, error);
return GetProjection(*relation.Cast<OrderRelation>().child);
case RelationType::SET_OPERATION_RELATION:
error += " -> ";
return GetProjectionOrTableRelation(*relation.Cast<SetOpRelation>().right, error);
return GetProjection(*relation.Cast<SetOpRelation>().right);
default:
throw NotImplementedException(
"Relation %s is not yet implemented as a possible root chain type of from_substrait function", error);
return nullptr;
}
}

Expand All @@ -691,21 +683,23 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
vector<unique_ptr<ParsedExpression>> expressions;
int id = 1;
auto child = TransformOp(sop.input());
string error;
auto first_projection_or_table = GetProjectionOrTableRelation(*child, error);
vector<ColumnDefinition> *column_definitions;
if (first_projection_or_table->type == RelationType::PROJECTION_RELATION) {
column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
auto first_projection_or_table = GetProjection(*child);
if (first_projection_or_table) {
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
int32_t i = 0;
for (auto &column : *column_definitions) {
aliases.push_back(column_names[i++]);
auto column_type = column.GetType();
i += SkipColumnNames(column.GetType());
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
}
} else {
column_definitions = &first_projection_or_table->Cast<TableRelation>().description->columns;
}
int32_t i = 0;
for (auto &column : *column_definitions) {
aliases.push_back(column_names[i++]);
auto column_type = column.GetType();
i += SkipColumnNames(column.GetType());
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
for (auto &column_name : column_names) {
aliases.push_back(column_name);
expressions.push_back(make_uniq<PositionalReferenceExpression>(id++));
}
}

return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
}

Expand Down
2 changes: 1 addition & 1 deletion src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ class SubstraitToDuckDB {
//! names
static const unordered_map<std::string, std::string> function_names_remap;
static const case_insensitive_set_t valid_extract_subfields;
vector<ParsedExpression*> struct_expressions;
vector<ParsedExpression *> struct_expressions;
};
} // namespace duckdb
5 changes: 2 additions & 3 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,12 @@ void DuckDBToSubstrait::TransformFunctionExpression(Expression &dexpr, substrait
uint64_t col_offset) {
auto &dfun = dexpr.Cast<BoundFunctionExpression>();


auto function_name = dfun.function.name;

if (function_name == "row") {
auto nested_expression = sexpr.mutable_nested();
auto struct_expression = nested_expression->mutable_struct_();
for (auto& child: dfun.children) {
for (auto &child : dfun.children) {
auto child_expression = struct_expression->add_fields();
TransformExpr(*child, *child_expression);
}
Expand All @@ -330,7 +329,7 @@ void DuckDBToSubstrait::TransformFunctionExpression(Expression &dexpr, substrait
if (function_name == "list_value" || function_name == "list_pack") {
auto nested_expression = sexpr.mutable_nested();
auto list_expression = nested_expression->mutable_list();
for (auto& child: dfun.children) {
for (auto &child : dfun.children) {
auto child_value = list_expression->add_values();
TransformExpr(*child, *child_value);
}
Expand Down
Loading

0 comments on commit f35aa93

Please sign in to comment.