Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for to_substrait and from_substrait for virtual table expression #130

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main_distribution.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
name: Build extension binaries
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
with:
duckdb_version: ca5af32c331f9d5ea49f7158d5c83a47f25b8b79
duckdb_version: c29c67bb971362cd1e9143305acffebb1bc9bd63
ci_tools_version: 5bdbe4d606d78dbd749f9578ba8ca639feece023
exclude_archs: "wasm_mvp;wasm_eh;wasm_threads;windows_amd64;windows_amd64_mingw;windows_amd64_rtools"
extension_name: substrait
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 359 files
51 changes: 38 additions & 13 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -515,7 +515,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 518 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 518 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -615,23 +615,28 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
auto literal_values = sget.virtual_table().values();
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
vector<Value> expression_row;
for (const auto &value : values) {
expression_row.emplace_back(TransformLiteralToValue(value));
if (!sget.virtual_table().values().empty()) {

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
vector<Value> expression_row;
for (const auto &value : values) {
expression_row.emplace_back(TransformLiteralToValue(value));
}
expression_rows.emplace_back(expression_row);
}
expression_rows.emplace_back(expression_row);
}
vector<string> column_names;
if (acquire_lock) {
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names);
vector<string> column_names;
if (acquire_lock) {
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names);

} else {
scan = make_shared_ptr<ValueRelation>(context_wrapper, expression_rows, column_names);
}
} else {
scan = make_shared_ptr<ValueRelation>(context_wrapper, expression_rows, column_names);
scan = GetValuesExpression(sget.virtual_table().expressions());
}

} else {
throw NotImplementedException("Unsupported type of read operator for substrait");
}
Expand All @@ -656,6 +661,26 @@
return scan;
}

//! Builds a ValueRelation out of the expression rows of a substrait virtual table.
//! Every field of every row is converted with TransformExpr; the column name list is left empty.
//! When acquire_lock is set the relation is created directly on `context`, otherwise on a
//! RelationContextWrapper constructed from `context`.
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(
    const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
	vector<vector<unique_ptr<ParsedExpression>>> parsed_rows;
	for (auto &substrait_row : expression_rows) {
		vector<unique_ptr<ParsedExpression>> parsed_row;
		for (const auto &field : substrait_row.fields()) {
			parsed_row.emplace_back(TransformExpr(field));
		}
		parsed_rows.emplace_back(std::move(parsed_row));
	}

	vector<string> column_names;
	if (acquire_lock) {
		return make_shared_ptr<ValueRelation>(context, std::move(parsed_rows), column_names);
	}
	auto context_wrapper = make_shared_ptr<RelationContextWrapper>(context);
	return make_shared_ptr<ValueRelation>(context_wrapper, std::move(parsed_rows), column_names);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<OrderByNode> order_nodes;
Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SubstraitToDuckDB {
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
shared_ptr<Relation> GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformOrderBy(LogicalOperator &dop);
substrait::Rel *TransformComparisonJoin(LogicalOperator &dop);
substrait::Rel *TransformAggregateGroup(LogicalOperator &dop);
substrait::Rel *TransformExpressionGet(LogicalOperator &dop);
substrait::Rel *TransformGet(LogicalOperator &dop);
substrait::Rel *TransformCrossProduct(LogicalOperator &dop);
substrait::Rel *TransformUnion(LogicalOperator &dop);
Expand Down
21 changes: 21 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -1012,7 +1012,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1280,7 +1280,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down Expand Up @@ -1339,6 +1339,25 @@
return get_rel;
}

//! Converts a LogicalExpressionGet into a substrait read relation backed by a virtual table.
//! Each row of dget.expressions becomes one entry of virtual_table.expressions, and each
//! expression of that row becomes one field of the entry.
//! Returns a heap-allocated Rel that is handed to the caller as a raw pointer.
substrait::Rel *DuckDBToSubstrait::TransformExpressionGet(LogicalOperator &dop) {
	// Keep the relation in a unique_ptr so it is freed if TransformExpr throws part-way through
	auto get_rel = make_uniq<substrait::Rel>();
	auto &dget = dop.Cast<LogicalExpressionGet>();

	auto virtual_table = get_rel->mutable_read()->mutable_virtual_table();
	for (auto &row : dget.expressions) {
		auto row_item = virtual_table->add_expressions();
		for (auto &expr : row) {
			// Transform straight into the newly added field: avoids a temporary heap
			// allocation, a deep copy of the message, and a manual delete that would be
			// skipped on an exception
			TransformExpr(*expr, *row_item->add_fields());
		}
	}
	return get_rel.release();
}

substrait::Rel *DuckDBToSubstrait::TransformCrossProduct(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto sub_cross_prod = rel->mutable_cross();
Expand Down Expand Up @@ -1537,6 +1556,8 @@
return TransformAggregateGroup(dop);
case LogicalOperatorType::LOGICAL_GET:
return TransformGet(dop);
case LogicalOperatorType::LOGICAL_EXPRESSION_GET:
return TransformExpressionGet(dop);
case LogicalOperatorType::LOGICAL_CROSS_PRODUCT:
return TransformCrossProduct(dop);
case LogicalOperatorType::LOGICAL_UNION:
Expand Down
27 changes: 27 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <chrono>
#include <thread>
#include <iostream>

using namespace duckdb;
using namespace std;
Expand Down Expand Up @@ -293,3 +294,29 @@ TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 95000}));
}

// Round-trips a VALUES list of plain literals: SQL -> substrait JSON -> query result.
TEST_CASE("Test C VirtualTable input Literal", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	// Serialize the query into a substrait JSON plan
	auto plan_json = con.GetSubstraitJSON("select * from (values (1, 2),(3, 4))");
	REQUIRE(!plan_json.empty());
	std::cout << plan_json << std::endl;

	// Execute the plan and check both columns of the two rows
	auto round_trip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(round_trip, 0, {1, 3}));
	REQUIRE(CHECK_COLUMN(round_trip, 1, {2, 4}));
}

// Round-trips a VALUES list whose cells are arithmetic expressions: SQL -> substrait JSON -> query result.
TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	// Serialize the query into a substrait JSON plan
	auto plan_json = con.GetSubstraitJSON("select * from (values (1+1,2+2),(3+3,4+4)) as temp(a,b)");
	REQUIRE(!plan_json.empty());
	std::cout << plan_json << std::endl;

	// Execute the plan and check the evaluated values of both columns
	auto round_trip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(round_trip, 0, {2, 6}));
	REQUIRE(CHECK_COLUMN(round_trip, 1, {4, 8}));
}
Loading