Skip to content

Commit

Permalink
Woo, queries 4 and 17 with delimiter joins now seem to work
Browse files Browse the repository at this point in the history
  • Loading branch information
pdet committed Jul 10, 2024
1 parent db7f93e commit a8529d1
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 311 deletions.
20 changes: 10 additions & 10 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,22 +438,22 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformDelimJoinOp(const substrait::Re

JoinType djointype;
switch (sjoin.type()) {
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_INNER:
djointype = JoinType::INNER;
break;
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_LEFT:
djointype = JoinType::LEFT;
break;
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_RIGHT:
djointype = JoinType::RIGHT;
break;
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_SINGLE:
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_SINGLE:
djointype = JoinType::SINGLE;
break;
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_SEMI:
djointype = JoinType::SEMI;
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
djointype = JoinType::RIGHT_SEMI;
break;
case substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_MARK:
case substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_MARK:
djointype = JoinType::MARK;
break;
default:
Expand All @@ -462,9 +462,10 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformDelimJoinOp(const substrait::Re
unique_ptr<ParsedExpression> join_condition = TransformExpr(sjoin.expression());
auto left_op = TransformOp(sjoin.left())->Alias("left");
auto right_op = TransformOp(sjoin.right())->Alias("right");
auto join = make_shared_ptr<JoinRelation>(std::move(left_op), std::move(right_op), std::move(join_condition), djointype);
auto join =
make_shared_ptr<JoinRelation>(std::move(left_op), std::move(right_op), std::move(join_condition), djointype);
join->delim_flipped = sjoin.delim_flipped();
for (auto& col: sjoin.duplicate_eliminated_columns()) {
for (auto &col : sjoin.duplicate_eliminated_columns()) {
join->duplicate_eliminated_columns.emplace_back(TransformExpr(col));
}
return join;
Expand Down Expand Up @@ -699,7 +700,6 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformPlan() {
plan.DebugString() +
"Not Implemented error message: " + parsed_error.RawMessage());
}

return d_plan;
}

Expand Down
10 changes: 7 additions & 3 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ shared_ptr<Relation> SubstraitPlanToDuckDBRel(Connection &conn, const string &se

static void VerifySubstraitRoundtrip(unique_ptr<LogicalOperator> &query_plan, Connection &con,
ToSubstraitFunctionData &data, const string &serialized, bool is_json) {
// We round-trip the generated JSON and verify that the result is the same
auto actual_result = con.Query(data.query);

bool is_optimizer_enabled = con.context->config.enable_optimizer;
con.context->config.enable_optimizer = false;
auto sub_relation = SubstraitPlanToDuckDBRel(con, serialized, is_json);
unique_ptr<QueryResult> substrait_result;

try {
substrait_result = sub_relation->Execute();
} catch (std::exception &ex) {
Expand All @@ -102,6 +102,10 @@ static void VerifySubstraitRoundtrip(unique_ptr<LogicalOperator> &query_plan, Co
sub_relation->Print();
throw InternalException("Substrait Plan Execution Failed");
}
con.context->config.enable_optimizer = is_optimizer_enabled;
// We round-trip the generated JSON and verify that the result is the same
auto actual_result = con.Query(data.query);

substrait_result->names = actual_result->names;
unique_ptr<MaterializedQueryResult> substrait_materialized;

Expand Down
29 changes: 20 additions & 9 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,9 +949,12 @@ substrait::Rel *DuckDBToSubstrait::TransformDelimiterJoin(LogicalOperator &dop)
case JoinType::SINGLE:
sjoin->set_type(substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_SINGLE);
break;
// case JoinType::SINGLE:
// sjoin->set_type(substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_SINGLE);
// break;
case JoinType::RIGHT_SEMI:
sjoin->set_type(substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI);
break;
case JoinType::MARK:
sjoin->set_type(substrait::DelimiterJoinRel_JoinType::DelimiterJoinRel_JoinType_JOIN_TYPE_MARK);
break;
default:
throw InternalException("Unsupported join type " + JoinTypeToString(djoin.join_type));
}
Expand All @@ -976,14 +979,22 @@ substrait::Rel *DuckDBToSubstrait::TransformDelimiterJoin(LogicalOperator &dop)
}
auto proj_rel = new substrait::Rel();
auto projection = proj_rel->mutable_project();
for (auto left_idx : djoin.left_projection_map) {
CreateFieldRef(projection->add_expressions(), left_idx);
}
if (djoin.join_type != JoinType::SEMI) {
for (auto right_idx : djoin.right_projection_map) {
CreateFieldRef(projection->add_expressions(), right_idx + left_col_count);
if (djoin.join_type == JoinType::RIGHT_SEMI) {
// We project everything from the right table
for (uint64_t i = 0; i < dop.children[1]->types.size(); i++) {
CreateFieldRef(projection->add_expressions(), i);
}
} else {
for (auto left_idx : djoin.left_projection_map) {
CreateFieldRef(projection->add_expressions(), left_idx);
}
if (djoin.join_type != JoinType::SEMI) {
for (auto right_idx : djoin.right_projection_map) {
CreateFieldRef(projection->add_expressions(), right_idx + left_col_count);
}
}
}

projection->set_allocated_input(res);
return proj_rel;
}
Expand Down
22 changes: 15 additions & 7 deletions test/sql/test_substrait_subqueries.test
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,36 @@ statement ok
PRAGMA enable_verification

statement ok
CREATE TABLE integers_2 (i integer, j integer);
CREATE TABLE integers (i integer);

statement ok
INSERT INTO integers_2 VALUES (NULL,1)
INSERT INTO integers VALUES (NULL)

statement ok
CALL get_substrait('SELECT i, (select SUM(i) OR true from integers_2 where j = i1.j) FROM integers_2 i1;')
CREATE TABLE integers_2 (j integer);

statement ok
INSERT INTO integers_2 VALUES (NULL)

statement ok
CREATE TABLE integers (i integer);
CALL get_substrait('SELECT (select SUM(i) OR true as i from integers where i = integers_2.j) FROM integers_2;')

statement ok
drop table integers

statement ok
insert into integers values (1),(2),(3),(NULL);
CREATE TABLE integers (i integer, j integer);

statement ok
insert into integers values (1,1),(2,2),(3,3),(NULL,NULL);

# Uncorrelated Scalar
statement ok
CALL get_substrait('select i, i + (select MIN(i) from integers) from integers order by i')

# Uncorrelated Any (Missing Mark Join)
statement ok
CALL get_substrait('select i = ANY(select * from integers where i is not null) from integers')
#statement ok
#CALL get_substrait('select i = ANY(select * from integers where i is not null) from integers')

# Uncorrelated Exists
statement ok
Expand Down
24 changes: 12 additions & 12 deletions test/sql/test_substrait_tpch.test
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ statement ok
PRAGMA enable_verification

statement ok
CALL dbgen(sf=0.01)
CALL dbgen(sf=0.001)

#Q 01
#statement ok
#CALL get_substrait('SELECT l_returnflag, l_linestatus, sum(l_quantity) AS sum_qty, sum(l_extendedprice) AS sum_base_price, sum(l_extendedprice * (1 - l_discount)) AS sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, avg(l_quantity) AS avg_qty, avg(l_extendedprice) AS avg_price, avg(l_discount) AS avg_disc, count(*) AS count_order FROM lineitem WHERE l_shipdate <= CAST(''1998-09-02'' AS date) GROUP BY l_returnflag, l_linestatus ORDER BY l_returnflag, l_linestatus;')
statement ok
CALL get_substrait('SELECT l_returnflag, l_linestatus, sum(l_quantity) AS sum_qty, sum(l_extendedprice) AS sum_base_price, sum(l_extendedprice * (1 - l_discount)) AS sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, avg(l_quantity) AS avg_qty, avg(l_extendedprice) AS avg_price, avg(l_discount) AS avg_disc, count(*) AS count_order FROM lineitem WHERE l_shipdate <= CAST(''1998-09-02'' AS date) GROUP BY l_returnflag, l_linestatus ORDER BY l_returnflag, l_linestatus;')

#Q 02 (DELIM_JOIN)
#statement ok
#CALL get_substrait('SELECT s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment FROM part, supplier, partsupp, nation, region WHERE p_partkey = ps_partkey AND s_suppkey = ps_suppkey AND p_size = 15 AND p_type LIKE ''%BRASS'' AND s_nationkey = n_nationkey AND n_regionkey = r_regionkey AND r_name = ''EUROPE'' AND ps_supplycost = ( SELECT min(ps_supplycost) FROM partsupp, supplier, nation, region WHERE p_partkey = ps_partkey AND s_suppkey = ps_suppkey AND s_nationkey = n_nationkey AND n_regionkey = r_regionkey AND r_name = ''EUROPE'') ORDER BY s_acctbal DESC, n_name, s_name, p_partkey LIMIT 100;')
#Q 02
statement ok
CALL get_substrait('SELECT s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment FROM part, supplier, partsupp, nation, region WHERE p_partkey = ps_partkey AND s_suppkey = ps_suppkey AND p_size = 15 AND p_type LIKE ''%BRASS'' AND s_nationkey = n_nationkey AND n_regionkey = r_regionkey AND r_name = ''EUROPE'' AND ps_supplycost = ( SELECT min(ps_supplycost) FROM partsupp, supplier, nation, region WHERE p_partkey = ps_partkey AND s_suppkey = ps_suppkey AND s_nationkey = n_nationkey AND n_regionkey = r_regionkey AND r_name = ''EUROPE'') ORDER BY s_acctbal DESC, n_name, s_name, p_partkey LIMIT 100;')

#Q 03
statement ok
CALL get_substrait('SELECT l_orderkey, sum(l_extendedprice * (1 - l_discount)) AS revenue, o_orderdate, o_shippriority FROM customer, orders, lineitem WHERE c_mktsegment = ''BUILDING'' AND c_custkey = o_custkey AND l_orderkey = o_orderkey AND o_orderdate < CAST(''1995-03-15'' AS date) AND l_shipdate > CAST(''1995-03-15'' AS date) GROUP BY l_orderkey, o_orderdate, o_shippriority ORDER BY revenue DESC, o_orderdate LIMIT 10;')

#Q 04 DELIM_JOIN
#statement ok
#CALL get_substrait('SELECT o_orderpriority, count(*) AS order_count FROM orders WHERE o_orderdate >= CAST(''1993-07-01'' AS date) AND o_orderdate < CAST(''1993-10-01'' AS date) AND EXISTS ( SELECT * FROM lineitem WHERE l_orderkey = o_orderkey AND l_commitdate < l_receiptdate) GROUP BY o_orderpriority ORDER BY o_orderpriority;')
#Q 04
statement ok
CALL get_substrait('SELECT o_orderpriority, count(*) AS order_count FROM orders WHERE o_orderdate >= CAST(''1993-07-01'' AS date) AND o_orderdate < CAST(''1993-10-01'' AS date) AND EXISTS ( SELECT * FROM lineitem WHERE l_orderkey = o_orderkey AND l_commitdate < l_receiptdate) GROUP BY o_orderpriority ORDER BY o_orderpriority;')

#Q 05
statement ok
Expand Down Expand Up @@ -82,9 +82,9 @@ CALL get_substrait('SELECT s_suppkey, s_name, s_address, s_phone, total_revenue
statement ok
CALL get_substrait('SELECT p_brand, p_type, p_size, count(DISTINCT ps_suppkey) AS supplier_cnt FROM partsupp, part WHERE p_partkey = ps_partkey AND p_brand <> ''Brand#45'' AND p_type NOT LIKE ''MEDIUM POLISHED%'' AND p_size IN (49, 14, 23, 45, 19, 3, 36, 9) AND ps_suppkey NOT IN ( SELECT s_suppkey FROM supplier WHERE s_comment LIKE ''%Customer%Complaints%'') GROUP BY p_brand, p_type, p_size ORDER BY supplier_cnt DESC, p_brand, p_type, p_size;')

#Q 17 (DELIM_JOIN)
#statement ok
#CALL get_substrait('SELECT sum(l_extendedprice) / 7.0 AS avg_yearly FROM lineitem, part WHERE p_partkey = l_partkey AND p_brand = ''Brand#23'' AND p_container = ''MED BOX'' AND l_quantity < ( SELECT 0.2 * avg(l_quantity) FROM lineitem WHERE l_partkey = p_partkey);')
#Q 17
statement ok
CALL get_substrait('SELECT sum(l_extendedprice) / 7.0 AS avg_yearly FROM lineitem, part WHERE p_partkey = l_partkey AND p_brand = ''Brand#23'' AND p_container = ''MED BOX'' AND l_quantity < ( SELECT 0.2 * avg(l_quantity) FROM lineitem WHERE l_partkey = p_partkey);')

#Q 18
statement ok
Expand Down
Loading

0 comments on commit a8529d1

Please sign in to comment.