From 650b899a0e8e7a4eb67a48358ba33660eba095de Mon Sep 17 00:00:00 2001 From: Anshul Data Date: Thu, 14 Nov 2024 10:12:55 +0530 Subject: [PATCH] feat: Add support for to_substrait for virtual table expression --- src/include/to_substrait.hpp | 1 + src/to_substrait.cpp | 21 +++++++++++++++++++++ test/c/test_substrait_c_api.cpp | 21 +++++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index aa063b8..30db7ed 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -59,6 +59,7 @@ class DuckDBToSubstrait { substrait::Rel *TransformOrderBy(LogicalOperator &dop); substrait::Rel *TransformComparisonJoin(LogicalOperator &dop); substrait::Rel *TransformAggregateGroup(LogicalOperator &dop); + substrait::Rel *TransformExpressionGet(LogicalOperator &dop); substrait::Rel *TransformGet(LogicalOperator &dop); substrait::Rel *TransformCrossProduct(LogicalOperator &dop); substrait::Rel *TransformUnion(LogicalOperator &dop); diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 5fb7f2b..53c97c6 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -1339,6 +1339,25 @@ substrait::Rel *DuckDBToSubstrait::TransformGet(LogicalOperator &dop) { return get_rel; } +substrait::Rel *DuckDBToSubstrait::TransformExpressionGet(LogicalOperator &dop) { + auto get_rel = new substrait::Rel(); + auto &dget = dop.Cast(); + + auto sget = get_rel->mutable_read(); + auto virtual_table = sget->mutable_virtual_table(); + + for (auto &row : dget.expressions) { + auto row_item = virtual_table->add_expressions(); + for (auto &expr : row) { + auto s_expr = new substrait::Expression(); + TransformExpr(*expr, *s_expr); + *row_item->add_fields() = *s_expr; + delete s_expr; + } + } + return get_rel; +} + substrait::Rel *DuckDBToSubstrait::TransformCrossProduct(LogicalOperator &dop) { auto rel = new substrait::Rel(); auto sub_cross_prod = rel->mutable_cross(); @@ -1435,6 +1454,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformAggregateGroup(dop); case LogicalOperatorType::LOGICAL_GET: return TransformGet(dop); + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + return TransformExpressionGet(dop); case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: return TransformCrossProduct(dop); case LogicalOperatorType::LOGICAL_UNION: diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index fe79d65..6307bf1 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -4,6 +4,7 @@ #include #include +#include using namespace duckdb; using namespace std; @@ -45,3 +46,23 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") { REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid")); } + + +TEST_CASE("Test C VirtualTable input Literal", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + auto json = con.GetSubstraitJSON("select * from (values (1, 2),(3, 4))"); + REQUIRE(!json.empty()); + std::cout << json << std::endl; +} + +TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + auto json = con.GetSubstraitJSON("select * from (values (1+1,2+2),(3+3,4+4)) as temp(a,b)"); + REQUIRE(!json.empty()); + std::cout << json << std::endl; + +}