Skip to content

Commit

Permalink
Merge branch 'view' into structs
Browse files Browse the repository at this point in the history
  • Loading branch information
pdet committed Nov 5, 2024
2 parents d488993 + bac2ff3 commit 0a21b92
Show file tree
Hide file tree
Showing 11 changed files with 610 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main_distribution.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
with:
duckdb_version: main
ci_tools_version: main
exclude_archs: "wasm_mvp;wasm_eh;wasm_threads;windows_amd64;windows_amd64_rtools"
extension_name: substrait

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ all: release
EXT_NAME=substrait
EXT_CONFIG=${PROJ_DIR}extension_config.cmake

CORE_EXTENSIONS='tpch;json'
CORE_EXTENSIONS='tpch;tpcds;json'

# Set this flag during building to enable the benchmark runner
ifeq (${BUILD_BENCHMARK}, 1)
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 1863 files
2 changes: 1 addition & 1 deletion duckdb-r
Submodule duckdb-r updated 640 files
39 changes: 29 additions & 10 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformScalarFunctionExpr(cons
return make_uniq<ComparisonExpression>(ExpressionType::COMPARE_NOT_DISTINCT_FROM, std::move(children[0]),
std::move(children[1]));
} else if (function_name == "between") {
// FIXME: ADD between to substrait extension
D_ASSERT(children.size() == 3);
return make_uniq<BetweenExpression>(std::move(children[0]), std::move(children[1]), std::move(children[2]));
} else if (function_name == "extract") {
Expand Down Expand Up @@ -541,17 +540,14 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Re
std::move(groups));
}
unique_ptr<TableDescription> TableInfo(ClientContext &context, const string &schema_name, const string &table_name) {
unique_ptr<TableDescription> result;
// obtain the table info
auto table = Catalog::GetEntry<TableCatalogEntry>(context, INVALID_CATALOG, schema_name, table_name,
OnEntryNotFound::RETURN_NULL);
if (!table) {
return {};
}
// write the table info to the result
result = make_uniq<TableDescription>();
result->schema = schema_name;
result->table = table_name;
auto result = make_uniq<TableDescription>(INVALID_CATALOG, schema_name, table_name);
for (auto &column : table->GetColumns().Logical()) {
result->columns.emplace_back(column.Copy());
}
Expand All @@ -561,6 +557,7 @@ unique_ptr<TableDescription> TableInfo(ClientContext &context, const string &sch
shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &sop) {
auto &sget = sop.read();
shared_ptr<Relation> scan;
auto context_wrapper = make_shared_ptr<RelationContextWrapper>(context);

Check failure on line 560 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

unknown type name 'RelationContextWrapper'; did you mean 'ClientContextWrapper'?

Check failure on line 560 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Linux (linux_amd64_gcc4, quay.io/pypa/manylinux2014_x86_64, x64-linux)

‘RelationContextWrapper’ was not declared in this scope; did you mean ‘ClientContextWrapper’?

Check failure on line 560 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Linux (linux_amd64_gcc4, quay.io/pypa/manylinux2014_x86_64, x64-linux)

no matching function for call to ‘make_shared_ptr<<expression error> >(duckdb::shared_ptr<duckdb::ClientContext, true>&)’

Check failure on line 560 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Linux (linux_amd64_gcc4, quay.io/pypa/manylinux2014_x86_64, x64-linux)

template argument 1 is invalid
if (sget.has_named_table()) {
auto table_name = sget.named_table().names(0);
// If we can't find a table with that name, let's try a view.
Expand All @@ -569,9 +566,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
if (!table_info) {
throw CatalogException("Table '%s' does not exist!", table_name);
}
scan = make_shared_ptr<TableRelation>(context, std::move(table_info), acquire_lock);
if (acquire_lock) {
scan = make_shared_ptr<TableRelation>(context, std::move(table_info));

} else {
scan = make_shared_ptr<TableRelation>(context_wrapper, std::move(table_info));
}
} catch (...) {
scan = make_shared_ptr<ViewRelation>(context, DEFAULT_SCHEMA, table_name, acquire_lock);
if (acquire_lock) {
scan = make_shared_ptr<ViewRelation>(context, DEFAULT_SCHEMA, table_name);

} else {
scan = make_shared_ptr<ViewRelation>(context_wrapper, DEFAULT_SCHEMA, table_name);
}
}
} else if (sget.has_local_files()) {
vector<Value> parquet_files;
Expand All @@ -593,8 +600,15 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
string name = "parquet_" + StringUtil::GenerateRandomName();
named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(false)}});
vector<Value> parameters {Value::LIST(parquet_files)};
auto scan_rel = make_shared_ptr<TableFunctionRelation>(
context, "parquet_scan", parameters, std::move(named_parameters), nullptr, true, acquire_lock);
shared_ptr<TableFunctionRelation> scan_rel;
if (acquire_lock) {
scan_rel = make_shared_ptr<TableFunctionRelation>(context, "parquet_scan", parameters,
std::move(named_parameters));
} else {
scan_rel = make_shared_ptr<TableFunctionRelation>(context_wrapper, "parquet_scan", parameters,
std::move(named_parameters));
}

auto rel = static_cast<Relation *>(scan_rel.get());
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
Expand All @@ -610,7 +624,12 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
expression_rows.emplace_back(expression_row);
}
vector<string> column_names;
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names, "values", acquire_lock);
if (acquire_lock) {
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names);

} else {
scan = make_shared_ptr<ValueRelation>(context_wrapper, expression_rows, column_names);
}
} else {
throw NotImplementedException("Unsupported type of read operator for substrait");
}
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class DuckDBToSubstrait {
void TransformFunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
static void TransformConstantExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformComparisonExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformBetweenExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformNotNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformIsNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
Expand Down
38 changes: 32 additions & 6 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,29 @@ void DuckDBToSubstrait::TransformComparisonExpression(Expression &dexpr, substra
*scalar_fun->mutable_output_type() = DuckToSubstraitType(dcomp.return_type);
}

//! Serializes a bound DuckDB BETWEEN expression as a Substrait "between" scalar function call.
//!
//! The emitted call carries three arguments, in this order: input, lower bound, upper bound.
//! Its output type is the return type of the DuckDB expression.
//!
//! \param dexpr The DuckDB expression; must have type ExpressionType::COMPARE_BETWEEN
//! \param sexpr The Substrait expression that receives the scalar function call
//! \throws InternalException if dexpr is not a BETWEEN expression
//! \throws NotImplementedException if either bound of the BETWEEN is exclusive
void DuckDBToSubstrait::TransformBetweenExpression(Expression &dexpr, substrait::Expression &sexpr) {
	// Check the expression type *before* casting, so a mismatch surfaces as the intended exception
	// rather than going through a cast to the wrong expression class first.
	if (dexpr.type != ExpressionType::COMPARE_BETWEEN) {
		throw InternalException("Not a between comparison expression");
	}
	auto &dbetween = dexpr.Cast<BoundBetweenExpression>();

	// Substrait's "between" is inclusive on both ends. Emitting it for an exclusive bound would
	// silently change the meaning of the predicate, so refuse instead of producing a wrong plan.
	if (!dbetween.lower_inclusive || !dbetween.upper_inclusive) {
		throw NotImplementedException("BETWEEN with an exclusive bound is not supported");
	}

	auto scalar_fun = sexpr.mutable_scalar_function();

	// Register the function with the argument types first; the arguments themselves are added below.
	vector<::substrait::Type> args_types;
	args_types.emplace_back(DuckToSubstraitType(dbetween.input->return_type));
	args_types.emplace_back(DuckToSubstraitType(dbetween.lower->return_type));
	args_types.emplace_back(DuckToSubstraitType(dbetween.upper->return_type));
	scalar_fun->set_function_reference(RegisterFunction("between", args_types));

	// Argument order is (input, lower, upper); each operand is transformed with a column offset of 0.
	auto sarg = scalar_fun->add_arguments();
	TransformExpr(*dbetween.input, *sarg->mutable_value(), 0);
	sarg = scalar_fun->add_arguments();
	TransformExpr(*dbetween.lower, *sarg->mutable_value(), 0);
	sarg = scalar_fun->add_arguments();
	TransformExpr(*dbetween.upper, *sarg->mutable_value(), 0);

	*scalar_fun->mutable_output_type() = DuckToSubstraitType(dbetween.return_type);
}

void DuckDBToSubstrait::TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr,
uint64_t col_offset) {
auto &dconj = dexpr.Cast<BoundConjunctionExpression>();
Expand Down Expand Up @@ -537,6 +560,9 @@ void DuckDBToSubstrait::TransformExpr(Expression &dexpr, substrait::Expression &
case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
TransformComparisonExpression(dexpr, sexpr);
break;
case ExpressionType::COMPARE_BETWEEN:
TransformBetweenExpression(dexpr, sexpr);
break;
case ExpressionType::CONJUNCTION_AND:
case ExpressionType::CONJUNCTION_OR:
TransformConjunctionExpression(dexpr, sexpr, col_offset);
Expand All @@ -557,7 +583,7 @@ void DuckDBToSubstrait::TransformExpr(Expression &dexpr, substrait::Expression &
TransformNotExpression(dexpr, sexpr, col_offset);
break;
default:
throw InternalException(ExpressionTypeToString(dexpr.type));
throw NotImplementedException(ExpressionTypeToString(dexpr.type));
}
}

Expand Down Expand Up @@ -742,7 +768,7 @@ substrait::Expression *DuckDBToSubstrait::TransformJoinCond(const JoinCondition
join_comparision = "lt";
break;
default:
throw InternalException("Unsupported join comparison: " + ExpressionTypeToOperator(dcond.comparison));
throw NotImplementedException("Unsupported join comparison: " + ExpressionTypeToOperator(dcond.comparison));
}
vector<::substrait::Type> args_types;
auto scalar_fun = expr->mutable_scalar_function();
Expand Down Expand Up @@ -946,7 +972,7 @@ substrait::Rel *DuckDBToSubstrait::TransformComparisonJoin(LogicalOperator &dop)
sjoin->set_type(substrait::JoinRel::JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER);
break;
default:
throw InternalException("Unsupported join type " + JoinTypeToString(djoin.join_type));
throw NotImplementedException("Unsupported join type " + JoinTypeToString(djoin.join_type));
}
// somewhat odd semantics on our side
if (djoin.left_projection_map.empty()) {
Expand Down Expand Up @@ -984,15 +1010,15 @@ substrait::Rel *DuckDBToSubstrait::TransformAggregateGroup(LogicalOperator &dop)
for (auto &dgrp : daggr.groups) {
if (dgrp->type != ExpressionType::BOUND_REF) {
// TODO push projection or push substrait to allow expressions here
throw InternalException("No expressions in groupings yet");
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
if (dmeas->type != ExpressionType::BOUND_AGGREGATE) {
// TODO push projection or push substrait, too
throw InternalException("No non-aggregate expressions in measures yet");
throw NotImplementedException("No non-aggregate expressions in measures yet");
}
auto &daexpr = dmeas->Cast<BoundAggregateExpression>();

Expand Down Expand Up @@ -1422,7 +1448,7 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
case LogicalOperatorType::LOGICAL_DUMMY_SCAN:
return TransformDummyScan();
default:
throw InternalException(LogicalOperatorToString(dop.type));
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
}

Expand Down
14 changes: 14 additions & 0 deletions test/sql/test_between.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# name: test/sql/test_between.test
# description: Test BETWEEN comparison
# group: [sql]

# Everything below needs the substrait extension to be available.
require substrait

statement ok
PRAGMA enable_verification

# Test table: a single column x filled from range(100).
statement ok
create table t as select * from range(100) as t(x)

# Generate a Substrait plan for a query with a BETWEEN filter, with the optimizer disabled.
# NOTE: this only asserts that plan generation succeeds ("statement ok"); the contents of the
# produced plan are not checked here, and the plan is not executed back through from_substrait.
statement ok
CALL get_substrait('select * from t where x BETWEEN 4 AND 6;', enable_optimizer = false );
Loading

0 comments on commit 0a21b92

Please sign in to comment.