Skip to content

Commit

Permalink
Add support for structs from nested functions if we can refer to ori…
Browse files Browse the repository at this point in the history
…ginal naming from root
  • Loading branch information
pdet committed Sep 24, 2024
1 parent 3fa6609 commit d488993
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 26 deletions.
58 changes: 40 additions & 18 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,15 +337,25 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformInExpr(const substrait:
return make_uniq<OperatorExpression>(ExpressionType::COMPARE_IN, std::move(values));
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr) {
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr,
RootNameIterator *iterator) {
auto &nested_expression = sexpr.nested();
if (nested_expression.has_struct_()) {
auto &struct_expression = nested_expression.struct_();
vector<unique_ptr<ParsedExpression>> children;
for (auto &child : struct_expression.fields()) {
children.emplace_back(TransformExpr(child));
}
return make_uniq<FunctionExpression>("row", std::move(children));
if (iterator && !iterator->Finished() && iterator->Unique(children.size())) {
for (auto &child : children) {
child->alias = iterator->GetCurrentName();
iterator->Next();
}
return make_uniq<FunctionExpression>("struct_pack", std::move(children));
} else {
return make_uniq<FunctionExpression>("row", std::move(children));
}

} else if (nested_expression.has_list()) {
auto &list_expression = nested_expression.list();
vector<unique_ptr<ParsedExpression>> children;
Expand All @@ -367,7 +377,11 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
}
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr,
RootNameIterator *iterator) {
if (iterator) {
iterator->Next();
}
switch (sexpr.rex_type_case()) {
case substrait::Expression::RexTypeCase::kLiteral:
return TransformLiteralExpr(sexpr);
Expand All @@ -382,7 +396,7 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
case substrait::Expression::RexTypeCase::kSingularOrList:
return TransformInExpr(sexpr);
case substrait::Expression::RexTypeCase::kNested:
return TransformNested(sexpr);
return TransformNested(sexpr, iterator);
case substrait::Expression::RexTypeCase::kSubquery:
default:
throw InternalException("Unsupported expression type " + to_string(sexpr.rex_type_case()));
Expand Down Expand Up @@ -464,22 +478,27 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformCrossProductOp(const substrait:
TransformOp(sub_cross.right())->Alias("right"));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
auto &slimit = sop.fetch();
idx_t limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
idx_t offset = slimit.offset();
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input()), limit, offset);
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input(), names), limit, offset);
}

//! Transforms a Substrait filter relation into a DuckDB FilterRelation:
//! the filter's input relation goes through TransformOp and its condition
//! through TransformExpr.
shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) {
	const auto &filter_rel = sop.filter();
	// Both transforms are kept as direct call arguments, exactly as before.
	return make_shared_ptr<FilterRelation>(TransformOp(filter_rel.input()),
	                                       TransformExpr(filter_rel.condition()));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) {
shared_ptr<Relation>
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<unique_ptr<ParsedExpression>> expressions;
RootNameIterator iterator(names);

for (auto &sexpr : sop.project().expressions()) {
expressions.push_back(TransformExpr(sexpr));
expressions.push_back(TransformExpr(sexpr, &iterator));
}

vector<string> mock_aliases;
Expand Down Expand Up @@ -616,12 +635,13 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
return scan;
}

shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<OrderByNode> order_nodes;
for (auto &sordf : sop.sort().sorts()) {
order_nodes.push_back(TransformOrder(sordf));
}
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input(), names), std::move(order_nodes));
}

static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
Expand All @@ -641,7 +661,8 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop)
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
D_ASSERT(sop.has_set());
auto &set = sop.set();
auto set_op_type = set.op();
Expand All @@ -653,31 +674,32 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
throw NotImplementedException("The amount of inputs (%d) is not supported for this set operation", input_count);
}
auto lhs = TransformOp(inputs[0]);
auto rhs = TransformOp(inputs[1]);
auto rhs = TransformOp(inputs[1], names);

return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
switch (sop.rel_type_case()) {
case substrait::Rel::RelTypeCase::kJoin:
return TransformJoinOp(sop);
case substrait::Rel::RelTypeCase::kCross:
return TransformCrossProductOp(sop);
case substrait::Rel::RelTypeCase::kFetch:
return TransformFetchOp(sop);
return TransformFetchOp(sop, names);
case substrait::Rel::RelTypeCase::kFilter:
return TransformFilterOp(sop);
case substrait::Rel::RelTypeCase::kProject:
return TransformProjectOp(sop);
return TransformProjectOp(sop, names);
case substrait::Rel::RelTypeCase::kAggregate:
return TransformAggregateOp(sop);
case substrait::Rel::RelTypeCase::kRead:
return TransformReadOp(sop);
case substrait::Rel::RelTypeCase::kSort:
return TransformSortOp(sop);
return TransformSortOp(sop, names);
case substrait::Rel::RelTypeCase::kSet:
return TransformSetOp(sop);
return TransformSetOp(sop, names);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -719,7 +741,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
const auto &column_names = sop.names();
vector<unique_ptr<ParsedExpression>> expressions;
int id = 1;
auto child = TransformOp(sop.input());
auto child = TransformOp(sop.input(), &column_names);
auto first_projection_or_table = GetProjection(*child);
if (first_projection_or_table) {
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
Expand Down
57 changes: 50 additions & 7 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@

namespace duckdb {

struct RootNameIterator {
explicit RootNameIterator(const google::protobuf::RepeatedPtrField<std::string> *names) : names(names) {};
string GetCurrentName() const {
if (!names) {
return "";
}
if (iterator >= names->size()) {
throw InvalidInputException("Trying to access invalid root name at struct creation");
}
return (*names)[iterator];
}
void Next() {
++iterator;
}
bool Unique(idx_t count) const {
idx_t pos = iterator;
set<string> values;
for (idx_t i = 0; i < count; i++) {
if (values.find((*names)[pos]) != values.end()) {
return false;
}
values.insert((*names)[pos]);
pos++;
}
return true;
}
bool Finished() const {
if (!names) {
return true;
}
return iterator >= names->size();
}
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr;
int iterator = 0;
};

class SubstraitToDuckDB {
public:
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false,
Expand All @@ -27,26 +63,33 @@ class SubstraitToDuckDB {
//! Transforms Substrait Plan Root To a DuckDB Relation
shared_ptr<Relation> TransformRootOp(const substrait::RelRoot &sop);
//! Transform Substrait Operations to DuckDB Relations
shared_ptr<Relation> TransformOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformJoinOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformCrossProductOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformFilterOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
RootNameIterator *iterator = nullptr);
static unique_ptr<ParsedExpression> TransformLiteralExpr(const substrait::Expression &sexpr);
static unique_ptr<ParsedExpression> TransformSelectionExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformScalarFunctionExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformIfThenExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformCastExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformInExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr,
RootNameIterator *iterator = nullptr);

static void VerifyCorrectExtractSubfield(const string &subfield);
static string RemapFunctionName(const string &function_name);
Expand Down
148 changes: 147 additions & 1 deletion test/sql/test_nested_expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,150 @@ statement ok
CALL get_substrait('SELECT row(row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10))) from t;')

statement ok
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')

require tpch

statement ok
CALL dbgen(sf=0.01)

query I
CALL from_substrait_json('
{
"relations": [
{
"root": {
"input": {
"fetch": {
"common": {
"direct": {}
},
"input": {
"project": {
"common": {
"emit": {
"outputMapping": 8
}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"c_custkey",
"c_name",
"c_address",
"c_nationkey",
"c_phone",
"c_acctbal",
"c_mktsegment",
"c_comment"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i32": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"customer"
]
}
}
},
"expressions": [
{
"nested": {
"struct": {
"fields": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
},
{
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
}
]
}
}
}
]
}
},
"count": 3
}
},
"names": [
"test_struct",
"custid",
"custname"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}
')
----
{'custid': 1, 'custname': Customer#000000001}
{'custid': 2, 'custname': Customer#000000002}
{'custid': 3, 'custname': Customer#000000003}

0 comments on commit d488993

Please sign in to comment.