diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 7719eb3..868247f 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -336,7 +336,8 @@ unique_ptr SubstraitToDuckDB::TransformInExpr(const substrait: return make_uniq(ExpressionType::COMPARE_IN, std::move(values)); } -unique_ptr SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr) { +unique_ptr SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr, + RootNameIterator *iterator) { auto &nested_expression = sexpr.nested(); if (nested_expression.has_struct_()) { auto &struct_expression = nested_expression.struct_(); @@ -344,7 +345,16 @@ unique_ptr SubstraitToDuckDB::TransformNested(const substrait: for (auto &child : struct_expression.fields()) { children.emplace_back(TransformExpr(child)); } - return make_uniq("row", std::move(children)); + if (iterator && !iterator->Finished() && iterator->Unique(children.size())) { + for (auto &child : children) { + child->alias = iterator->GetCurrentName(); + iterator->Next(); + } + return make_uniq("struct_pack", std::move(children)); + } else { + return make_uniq("row", std::move(children)); + } + } else if (nested_expression.has_list()) { auto &list_expression = nested_expression.list(); vector> children; @@ -366,7 +376,11 @@ unique_ptr SubstraitToDuckDB::TransformNested(const substrait: } } -unique_ptr SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) { +unique_ptr SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr, + RootNameIterator *iterator) { + if (iterator) { + iterator->Next(); + } switch (sexpr.rex_type_case()) { case substrait::Expression::RexTypeCase::kLiteral: return TransformLiteralExpr(sexpr); @@ -381,7 +395,7 @@ unique_ptr SubstraitToDuckDB::TransformExpr(const substrait::E case substrait::Expression::RexTypeCase::kSingularOrList: return TransformInExpr(sexpr); case substrait::Expression::RexTypeCase::kNested: - return TransformNested(sexpr); + return TransformNested(sexpr, iterator); case substrait::Expression::RexTypeCase::kSubquery: default: throw InternalException("Unsupported expression type " + to_string(sexpr.rex_type_case())); @@ -463,11 +477,12 @@ shared_ptr SubstraitToDuckDB::TransformCrossProductOp(const substrait: TransformOp(sub_cross.right())->Alias("right")); } -shared_ptr SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) { +shared_ptr SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names) { auto &slimit = sop.fetch(); idx_t limit = slimit.count() == -1 ? NumericLimits::Maximum() : slimit.count(); idx_t offset = slimit.offset(); - return make_shared_ptr(TransformOp(slimit.input()), limit, offset); + return make_shared_ptr(TransformOp(slimit.input(), names), limit, offset); } shared_ptr SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) { @@ -475,10 +490,14 @@ shared_ptr SubstraitToDuckDB::TransformFilterOp(const substrait::Rel & return make_shared_ptr(TransformOp(sfilter.input()), TransformExpr(sfilter.condition())); } -shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) { +shared_ptr +SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names) { vector> expressions; + RootNameIterator iterator(names); + for (auto &sexpr : sop.project().expressions()) { - expressions.push_back(TransformExpr(sexpr)); + expressions.push_back(TransformExpr(sexpr, &iterator)); } vector mock_aliases; @@ -635,12 +654,13 @@ shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so return scan; } -shared_ptr SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop) { +shared_ptr SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names) { vector order_nodes; for (auto &sordf : sop.sort().sorts()) { order_nodes.push_back(TransformOrder(sordf)); } - return make_shared_ptr(TransformOp(sop.sort().input()), std::move(order_nodes)); + return make_shared_ptr(TransformOp(sop.sort().input(), names), std::move(order_nodes)); } static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) { @@ -660,7 +680,8 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) } } -shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop) { +shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names) { D_ASSERT(sop.has_set()); auto &set = sop.set(); auto set_op_type = set.op(); @@ -672,31 +693,32 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop throw NotImplementedException("The amount of inputs (%d) is not supported for this set operation", input_count); } auto lhs = TransformOp(inputs[0]); - auto rhs = TransformOp(inputs[1]); + auto rhs = TransformOp(inputs[1], names); return make_shared_ptr(std::move(lhs), std::move(rhs), type); } -shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) { +shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names) { switch (sop.rel_type_case()) { case substrait::Rel::RelTypeCase::kJoin: return TransformJoinOp(sop); case substrait::Rel::RelTypeCase::kCross: return TransformCrossProductOp(sop); case substrait::Rel::RelTypeCase::kFetch: - return TransformFetchOp(sop); + return TransformFetchOp(sop, names); case substrait::Rel::RelTypeCase::kFilter: return TransformFilterOp(sop); case substrait::Rel::RelTypeCase::kProject: - return TransformProjectOp(sop); + return TransformProjectOp(sop, names); case substrait::Rel::RelTypeCase::kAggregate: return TransformAggregateOp(sop); case substrait::Rel::RelTypeCase::kRead: return TransformReadOp(sop); case substrait::Rel::RelTypeCase::kSort: - return TransformSortOp(sop); + return TransformSortOp(sop, names); case substrait::Rel::RelTypeCase::kSet: - return TransformSetOp(sop); + return TransformSetOp(sop, names); default: throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case())); } @@ -738,7 +760,7 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot const auto &column_names = sop.names(); vector> expressions; int id = 1; - auto child = TransformOp(sop.input()); + auto child = TransformOp(sop.input(), &column_names); auto first_projection_or_table = GetProjection(*child); if (first_projection_or_table) { vector *column_definitions = &first_projection_or_table->Cast().columns; diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp index 9758bf2..95a4884 100644 --- a/src/include/from_substrait.hpp +++ b/src/include/from_substrait.hpp @@ -16,6 +16,42 @@ namespace duckdb { +struct RootNameIterator { + explicit RootNameIterator(const google::protobuf::RepeatedPtrField *names) : names(names) {}; + string GetCurrentName() const { + if (!names) { + return ""; + } + if (iterator >= names->size()) { + throw InvalidInputException("Trying to access invalid root name at struct creation"); + } + return (*names)[iterator]; + } + void Next() { + ++iterator; + } + bool Unique(idx_t count) const { + idx_t pos = iterator; + set values; + for (idx_t i = 0; i < count; i++) { + if (values.find((*names)[pos]) != values.end()) { + return false; + } + values.insert((*names)[pos]); + pos++; + } + return true; + } + bool Finished() const { + if (!names) { + return true; + } + return iterator >= names->size(); + } + const google::protobuf::RepeatedPtrField *names = nullptr; + int iterator = 0; +}; + class SubstraitToDuckDB { public: SubstraitToDuckDB(shared_ptr &context_p, const string &serialized, bool json = false, @@ -27,26 +63,33 @@ class SubstraitToDuckDB { //! Transforms Substrait Plan Root To a DuckDB Relation shared_ptr TransformRootOp(const substrait::RelRoot &sop); //! Transform Substrait Operations to DuckDB Relations - shared_ptr TransformOp(const substrait::Rel &sop); + shared_ptr TransformOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names = nullptr); shared_ptr TransformJoinOp(const substrait::Rel &sop); shared_ptr TransformCrossProductOp(const substrait::Rel &sop); - shared_ptr TransformFetchOp(const substrait::Rel &sop); + shared_ptr TransformFetchOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names = nullptr); shared_ptr TransformFilterOp(const substrait::Rel &sop); - shared_ptr TransformProjectOp(const substrait::Rel &sop); + shared_ptr TransformProjectOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names = nullptr); shared_ptr TransformAggregateOp(const substrait::Rel &sop); shared_ptr TransformReadOp(const substrait::Rel &sop); - shared_ptr TransformSortOp(const substrait::Rel &sop); - shared_ptr TransformSetOp(const substrait::Rel &sop); + shared_ptr TransformSortOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names = nullptr); + shared_ptr TransformSetOp(const substrait::Rel &sop, + const google::protobuf::RepeatedPtrField *names = nullptr); //! Transform Substrait Expressions to DuckDB Expressions - unique_ptr TransformExpr(const substrait::Expression &sexpr); + unique_ptr TransformExpr(const substrait::Expression &sexpr, + RootNameIterator *iterator = nullptr); static unique_ptr TransformLiteralExpr(const substrait::Expression &sexpr); static unique_ptr TransformSelectionExpr(const substrait::Expression &sexpr); unique_ptr TransformScalarFunctionExpr(const substrait::Expression &sexpr); unique_ptr TransformIfThenExpr(const substrait::Expression &sexpr); unique_ptr TransformCastExpr(const substrait::Expression &sexpr); unique_ptr TransformInExpr(const substrait::Expression &sexpr); - unique_ptr TransformNested(const substrait::Expression &sexpr); + unique_ptr TransformNested(const substrait::Expression &sexpr, + RootNameIterator *iterator = nullptr); static void VerifyCorrectExtractSubfield(const string &subfield); static string RemapFunctionName(const string &function_name); diff --git a/test/sql/test_nested_expressions.test b/test/sql/test_nested_expressions.test index 3f57181..54c3742 100644 --- a/test/sql/test_nested_expressions.test +++ b/test/sql/test_nested_expressions.test @@ -45,4 +45,150 @@ statement ok CALL get_substrait('SELECT row(row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10))) from t;') statement ok -CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;') \ No newline at end of file +CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;') + +require tpch + +statement ok +CALL dbgen(sf=0.01) + +query I +CALL from_substrait_json(' +{ + "relations": [ + { + "root": { + "input": { + "fetch": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": 8 + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "c_custkey", + "c_name", + "c_address", + "c_nationkey", + "c_phone", + "c_acctbal", + "c_mktsegment", + "c_comment" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "customer" + ] + } + } + }, + "expressions": [ + { + "nested": { + "struct": { + "fields": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + } + } + ] + } + }, + "count": 3 + } + }, + "names": [ + "test_struct", + "custid", + "custname" + ] + } + } + ], + "version": { + "minorNumber": 52, + "producer": "spark-substrait-gateway" + } +} +') +---- +{'custid': 1, 'custname': Customer#000000001} +{'custid': 2, 'custname': Customer#000000002} +{'custid': 3, 'custname': Customer#000000003} \ No newline at end of file