Skip to content

Commit

Permalink
Merge pull request #115 from pdet/structs
Browse files Browse the repository at this point in the history
Structs - Fix
  • Loading branch information
pdet authored Nov 6, 2024
2 parents bc9f4b3 + d07d339 commit 918a12d
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 26 deletions.
58 changes: 40 additions & 18 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,25 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformInExpr(const substrait:
return make_uniq<OperatorExpression>(ExpressionType::COMPARE_IN, std::move(values));
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr) {
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait::Expression &sexpr,
RootNameIterator *iterator) {
auto &nested_expression = sexpr.nested();
if (nested_expression.has_struct_()) {
auto &struct_expression = nested_expression.struct_();
vector<unique_ptr<ParsedExpression>> children;
for (auto &child : struct_expression.fields()) {
children.emplace_back(TransformExpr(child));
}
return make_uniq<FunctionExpression>("row", std::move(children));
if (iterator && !iterator->Finished() && iterator->Unique(children.size())) {
for (auto &child : children) {
child->alias = iterator->GetCurrentName();
iterator->Next();
}
return make_uniq<FunctionExpression>("struct_pack", std::move(children));
} else {
return make_uniq<FunctionExpression>("row", std::move(children));
}

} else if (nested_expression.has_list()) {
auto &list_expression = nested_expression.list();
vector<unique_ptr<ParsedExpression>> children;
Expand All @@ -366,7 +376,11 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformNested(const substrait:
}
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr) {
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::Expression &sexpr,
RootNameIterator *iterator) {
if (iterator) {
iterator->Next();
}
switch (sexpr.rex_type_case()) {
case substrait::Expression::RexTypeCase::kLiteral:
return TransformLiteralExpr(sexpr);
Expand All @@ -381,7 +395,7 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformExpr(const substrait::E
case substrait::Expression::RexTypeCase::kSingularOrList:
return TransformInExpr(sexpr);
case substrait::Expression::RexTypeCase::kNested:
return TransformNested(sexpr);
return TransformNested(sexpr, iterator);
case substrait::Expression::RexTypeCase::kSubquery:
default:
throw InternalException("Unsupported expression type " + to_string(sexpr.rex_type_case()));
Expand Down Expand Up @@ -463,22 +477,27 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformCrossProductOp(const substrait:
TransformOp(sub_cross.right())->Alias("right"));
}

//! Transforms a Substrait fetch relation into a LimitRelation built on top of its transformed input.
//! \param sop   Relation whose fetch() payload supplies the input relation, the row count and the offset.
//! \param names Optional list of names; it is not inspected here, only handed to TransformOp for the input.
//! \return A LimitRelation wrapping the transformed input with the computed limit and offset.
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
auto &slimit = sop.fetch();
// A count of -1 is mapped to the largest idx_t value, i.e. effectively no row limit.
idx_t limit = slimit.count() == -1 ? NumericLimits<idx_t>::Maximum() : slimit.count();
idx_t offset = slimit.offset();
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input()), limit, offset);
return make_shared_ptr<LimitRelation>(TransformOp(slimit.input(), names), limit, offset);
}

//! Transforms a Substrait filter relation into a FilterRelation.
//! \param sop Relation whose filter() payload supplies the input relation and the filter condition.
//! \return A FilterRelation over the transformed input, filtered by the transformed condition.
shared_ptr<Relation> SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) {
    const auto &filter_rel = sop.filter();
    // Transform the two parts in separate statements: as arguments of a single call their relative
    // evaluation order would be unspecified. The input relation is always transformed first.
    auto input_relation = TransformOp(filter_rel.input());
    auto condition = TransformExpr(filter_rel.condition());
    return make_shared_ptr<FilterRelation>(std::move(input_relation), std::move(condition));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) {
shared_ptr<Relation>
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<unique_ptr<ParsedExpression>> expressions;
RootNameIterator iterator(names);

for (auto &sexpr : sop.project().expressions()) {
expressions.push_back(TransformExpr(sexpr));
expressions.push_back(TransformExpr(sexpr, &iterator));
}

vector<string> mock_aliases;
Expand Down Expand Up @@ -635,12 +654,13 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
return scan;
}

//! Transforms a Substrait sort relation into an OrderRelation built on top of its transformed input.
//! \param sop   Relation whose sort() payload supplies the input relation and the list of sort fields.
//! \param names Optional list of names; it is not inspected here, only handed to TransformOp for the input.
//! \return An OrderRelation over the transformed input, ordered by the transformed sort fields.
shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<OrderByNode> order_nodes;
// Each sort field is converted with TransformOrder, keeping the order in which the plan lists them.
for (auto &sordf : sop.sort().sorts()) {
order_nodes.push_back(TransformOrder(sordf));
}
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input(), names), std::move(order_nodes));
}

static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
Expand All @@ -660,7 +680,8 @@ static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop)
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
D_ASSERT(sop.has_set());
auto &set = sop.set();
auto set_op_type = set.op();
Expand All @@ -672,31 +693,32 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
throw NotImplementedException("The amount of inputs (%d) is not supported for this set operation", input_count);
}
auto lhs = TransformOp(inputs[0]);
auto rhs = TransformOp(inputs[1]);
auto rhs = TransformOp(inputs[1], names);

return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
switch (sop.rel_type_case()) {
case substrait::Rel::RelTypeCase::kJoin:
return TransformJoinOp(sop);
case substrait::Rel::RelTypeCase::kCross:
return TransformCrossProductOp(sop);
case substrait::Rel::RelTypeCase::kFetch:
return TransformFetchOp(sop);
return TransformFetchOp(sop, names);
case substrait::Rel::RelTypeCase::kFilter:
return TransformFilterOp(sop);
case substrait::Rel::RelTypeCase::kProject:
return TransformProjectOp(sop);
return TransformProjectOp(sop, names);
case substrait::Rel::RelTypeCase::kAggregate:
return TransformAggregateOp(sop);
case substrait::Rel::RelTypeCase::kRead:
return TransformReadOp(sop);
case substrait::Rel::RelTypeCase::kSort:
return TransformSortOp(sop);
return TransformSortOp(sop, names);
case substrait::Rel::RelTypeCase::kSet:
return TransformSetOp(sop);
return TransformSetOp(sop, names);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -738,7 +760,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
const auto &column_names = sop.names();
vector<unique_ptr<ParsedExpression>> expressions;
int id = 1;
auto child = TransformOp(sop.input());
auto child = TransformOp(sop.input(), &column_names);
auto first_projection_or_table = GetProjection(*child);
if (first_projection_or_table) {
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
Expand Down
57 changes: 50 additions & 7 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@

namespace duckdb {

struct RootNameIterator {
explicit RootNameIterator(const google::protobuf::RepeatedPtrField<std::string> *names) : names(names) {};
string GetCurrentName() const {
if (!names) {
return "";
}
if (iterator >= names->size()) {
throw InvalidInputException("Trying to access invalid root name at struct creation");
}
return (*names)[iterator];
}
void Next() {
++iterator;
}
bool Unique(idx_t count) const {
idx_t pos = iterator;
set<string> values;
for (idx_t i = 0; i < count; i++) {
if (values.find((*names)[pos]) != values.end()) {
return false;
}
values.insert((*names)[pos]);
pos++;
}
return true;
}
bool Finished() const {
if (!names) {
return true;
}
return iterator >= names->size();
}
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr;
int iterator = 0;
};

class SubstraitToDuckDB {
public:
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false,
Expand All @@ -27,26 +63,33 @@ class SubstraitToDuckDB {
//! Transforms Substrait Plan Root To a DuckDB Relation
shared_ptr<Relation> TransformRootOp(const substrait::RelRoot &sop);
//! Transform Substrait Operations to DuckDB Relations
shared_ptr<Relation> TransformOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformJoinOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformCrossProductOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformFetchOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformFilterOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformProjectOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
RootNameIterator *iterator = nullptr);
static unique_ptr<ParsedExpression> TransformLiteralExpr(const substrait::Expression &sexpr);
static unique_ptr<ParsedExpression> TransformSelectionExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformScalarFunctionExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformIfThenExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformCastExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformInExpr(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr);
unique_ptr<ParsedExpression> TransformNested(const substrait::Expression &sexpr,
RootNameIterator *iterator = nullptr);

static void VerifyCorrectExtractSubfield(const string &subfield);
static string RemapFunctionName(const string &function_name);
Expand Down
148 changes: 147 additions & 1 deletion test/sql/test_nested_expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,150 @@ statement ok
CALL get_substrait('SELECT row(row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10)),row(row(a,a,10),row(a,a,10),row(a,a,10))) from t;')

statement ok
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')
CALL get_substrait('SELECT [[[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]], [[a,a,10], [a,a,10], [a,a,10]]] from t;')

require tpch

statement ok
CALL dbgen(sf=0.01)

query I
CALL from_substrait_json('
{
"relations": [
{
"root": {
"input": {
"fetch": {
"common": {
"direct": {}
},
"input": {
"project": {
"common": {
"emit": {
"outputMapping": 8
}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"c_custkey",
"c_name",
"c_address",
"c_nationkey",
"c_phone",
"c_acctbal",
"c_mktsegment",
"c_comment"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i32": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"customer"
]
}
}
},
"expressions": [
{
"nested": {
"struct": {
"fields": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
},
{
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
}
]
}
}
}
]
}
},
"count": 3
}
},
"names": [
"test_struct",
"custid",
"custname"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}
')
----
{'custid': 1, 'custname': Customer#000000001}
{'custid': 2, 'custname': Customer#000000002}
{'custid': 3, 'custname': Customer#000000003}

0 comments on commit 918a12d

Please sign in to comment.