Skip to content

Commit

Permalink
Merge pull request substrait-io#98 from pdet/virtual_table
Browse files Browse the repository at this point in the history
Implement Virtual Table support, fix compiler warnings and clang-tidy findings, and clean up test groups.
  • Loading branch information
pdet authored Aug 5, 2024
2 parents 8e3e848 + 3688dd3 commit 69af93d
Show file tree
Hide file tree
Showing 14 changed files with 168 additions and 122 deletions.
110 changes: 52 additions & 58 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,96 +84,79 @@ SubstraitToDuckDB::SubstraitToDuckDB(Connection &con_p, const string &serialized
}
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformLiteralExpr(const substrait::Expression &sexpr) {
const auto &slit = sexpr.literal();
Value dval;
if (slit.has_null()) {
dval = Value(LogicalType::SQLNULL);
return make_uniq<ConstantExpression>(dval);
Value TransformLiteralToValue(const substrait::Expression_Literal &literal) {
if (literal.has_null()) {
return Value(LogicalType::SQLNULL);
}
switch (slit.literal_type_case()) {
switch (literal.literal_type_case()) {
case substrait::Expression_Literal::LiteralTypeCase::kFp64:
dval = Value::DOUBLE(slit.fp64());
break;
return Value::DOUBLE(literal.fp64());
case substrait::Expression_Literal::LiteralTypeCase::kFp32:
dval = Value::FLOAT(slit.fp32());
break;
return Value::FLOAT(literal.fp32());
case substrait::Expression_Literal::LiteralTypeCase::kString:
dval = Value(slit.string());
break;
return {literal.string()};
case substrait::Expression_Literal::LiteralTypeCase::kDecimal: {
const auto &substrait_decimal = slit.decimal();
const auto &substrait_decimal = literal.decimal();
auto raw_value = (uint64_t *)substrait_decimal.value().c_str();
hugeint_t substrait_value;
hugeint_t substrait_value {};
substrait_value.lower = raw_value[0];
substrait_value.upper = raw_value[1];
Value val = Value::HUGEINT(substrait_value);
auto decimal_type = LogicalType::DECIMAL(substrait_decimal.precision(), substrait_decimal.scale());
// cast to correct value
switch (decimal_type.InternalType()) {
case PhysicalType::INT8:
dval = Value::DECIMAL(val.GetValue<int8_t>(), substrait_decimal.precision(), substrait_decimal.scale());
break;
return Value::DECIMAL(val.GetValue<int8_t>(), substrait_decimal.precision(), substrait_decimal.scale());
case PhysicalType::INT16:
dval = Value::DECIMAL(val.GetValue<int16_t>(), substrait_decimal.precision(), substrait_decimal.scale());
break;
return Value::DECIMAL(val.GetValue<int16_t>(), substrait_decimal.precision(), substrait_decimal.scale());
case PhysicalType::INT32:
dval = Value::DECIMAL(val.GetValue<int32_t>(), substrait_decimal.precision(), substrait_decimal.scale());
break;
return Value::DECIMAL(val.GetValue<int32_t>(), substrait_decimal.precision(), substrait_decimal.scale());
case PhysicalType::INT64:
dval = Value::DECIMAL(val.GetValue<int64_t>(), substrait_decimal.precision(), substrait_decimal.scale());
break;
return Value::DECIMAL(val.GetValue<int64_t>(), substrait_decimal.precision(), substrait_decimal.scale());
case PhysicalType::INT128:
dval = Value::DECIMAL(substrait_value, substrait_decimal.precision(), substrait_decimal.scale());
break;
return Value::DECIMAL(substrait_value, substrait_decimal.precision(), substrait_decimal.scale());
default:
throw InternalException("Not accepted internal type for decimal");
}
break;
}
case substrait::Expression_Literal::LiteralTypeCase::kBoolean: {
dval = Value(slit.boolean());
break;
return Value(literal.boolean());
}
case substrait::Expression_Literal::LiteralTypeCase::kI8:
dval = Value::TINYINT(slit.i8());
break;
return Value::TINYINT(literal.i8());
case substrait::Expression_Literal::LiteralTypeCase::kI32:
dval = Value::INTEGER(slit.i32());
break;
return Value::INTEGER(literal.i32());
case substrait::Expression_Literal::LiteralTypeCase::kI64:
dval = Value::BIGINT(slit.i64());
break;
return Value::BIGINT(literal.i64());
case substrait::Expression_Literal::LiteralTypeCase::kDate: {
date_t date(slit.date());
dval = Value::DATE(date);
break;
date_t date(literal.date());
return Value::DATE(date);
}
case substrait::Expression_Literal::LiteralTypeCase::kTime: {
dtime_t time(slit.time());
dval = Value::TIME(time);
break;
dtime_t time(literal.time());
return Value::TIME(time);
}
case substrait::Expression_Literal::LiteralTypeCase::kIntervalYearToMonth: {
interval_t interval;
interval.months = slit.interval_year_to_month().months();
interval_t interval {};
interval.months = literal.interval_year_to_month().months();
interval.days = 0;
interval.micros = 0;
dval = Value::INTERVAL(interval);
break;
return Value::INTERVAL(interval);
}
case substrait::Expression_Literal::LiteralTypeCase::kIntervalDayToSecond: {
interval_t interval;
interval_t interval {};
interval.months = 0;
interval.days = slit.interval_day_to_second().days();
interval.micros = slit.interval_day_to_second().microseconds();
dval = Value::INTERVAL(interval);
break;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();
return Value::INTERVAL(interval);
}
default:
throw InternalException(to_string(slit.literal_type_case()));
throw InternalException(to_string(literal.literal_type_case()));
}
return make_uniq<ConstantExpression>(dval);
}

//! Turns the literal carried by a Substrait expression into a DuckDB constant expression.
unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformLiteralExpr(const substrait::Expression &sexpr) {
	auto constant_value = TransformLiteralToValue(sexpr.literal());
	return make_uniq<ConstantExpression>(std::move(constant_value));
}

unique_ptr<ParsedExpression> SubstraitToDuckDB::TransformSelectionExpr(const substrait::Expression &sexpr) {
Expand Down Expand Up @@ -517,6 +500,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
string name = "parquet_" + StringUtil::GenerateRandomName();
named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(false)}});
scan = con.TableFunction("parquet_scan", {Value::LIST(parquet_files)}, named_parameters)->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
auto literal_values = sget.virtual_table().values();
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
vector<Value> expression_row;
for (const auto &value : values) {
expression_row.emplace_back(TransformLiteralToValue(value));
}
expression_rows.emplace_back(expression_row);
}
scan = con.Values(expression_rows);
} else {
throw NotImplementedException("Unsupported type of read operator for substrait");
}
Expand All @@ -535,7 +531,6 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
// TODO make sure nothing else is in there
expressions.push_back(make_uniq<PositionalReferenceExpression>(sproj.field() + 1));
}

scan = make_shared_ptr<ProjectionRelation>(std::move(scan), std::move(expressions), std::move(aliases));
}

Expand All @@ -550,20 +545,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &so
return make_shared_ptr<OrderRelation>(TransformOp(sop.sort().input()), std::move(order_nodes));
}

static duckdb::SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
static SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) {
switch (setop) {
case substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
return duckdb::SetOperationType::UNION;
return SetOperationType::UNION;
}
case substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_MINUS_PRIMARY: {
return duckdb::SetOperationType::EXCEPT;
return SetOperationType::EXCEPT;
}
case substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_INTERSECTION_PRIMARY: {
return duckdb::SetOperationType::INTERSECT;
return SetOperationType::INTERSECT;
}
default: {
throw duckdb::NotImplementedException("SetOperationType transform not implemented for SetRel_SetOp type %d",
setop);
throw NotImplementedException("SetOperationType transform not implemented for SetRel_SetOp type %d", setop);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

#include <string>
#include <unordered_map>
#include <memory>
#include "substrait/plan.pb.h"
#include "duckdb/main/connection.hpp"
#include "duckdb/common/shared_ptr.hpp"

namespace duckdb {

class SubstraitToDuckDB {
public:
SubstraitToDuckDB(Connection &con_p, const string &serialized, bool json = false);
Expand Down
90 changes: 45 additions & 45 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace duckdb {
class DuckDBToSubstrait {
public:
explicit DuckDBToSubstrait(ClientContext &context, duckdb::LogicalOperator &dop, bool strict_p)
explicit DuckDBToSubstrait(ClientContext &context, LogicalOperator &dop, bool strict_p)
: context(context), strict(strict_p) {
TransformPlan(dop);
};
Expand All @@ -30,7 +30,7 @@ class DuckDBToSubstrait {

private:
//! Transform DuckDB Plan to Substrait Plan
void TransformPlan(duckdb::LogicalOperator &dop);
void TransformPlan(LogicalOperator &dop);
//! Registers a function
uint64_t RegisterFunction(const std::string &name, vector<::substrait::Type> &args_types);
//! Creates a reference to a table column
Expand All @@ -40,76 +40,76 @@ class DuckDBToSubstrait {
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);

//! Methods to Transform Logical Operators to Substrait Relations
substrait::Rel *TransformOp(duckdb::LogicalOperator &dop);
substrait::Rel *TransformFilter(duckdb::LogicalOperator &dop);
substrait::Rel *TransformProjection(duckdb::LogicalOperator &dop);
substrait::Rel *TransformTopN(duckdb::LogicalOperator &dop);
substrait::Rel *TransformLimit(duckdb::LogicalOperator &dop);
substrait::Rel *TransformOrderBy(duckdb::LogicalOperator &dop);
substrait::Rel *TransformComparisonJoin(duckdb::LogicalOperator &dop);
substrait::Rel *TransformAggregateGroup(duckdb::LogicalOperator &dop);
substrait::Rel *TransformGet(duckdb::LogicalOperator &dop);
substrait::Rel *TransformCrossProduct(duckdb::LogicalOperator &dop);
substrait::Rel *TransformUnion(duckdb::LogicalOperator &dop);
substrait::Rel *TransformDistinct(duckdb::LogicalOperator &dop);
substrait::Rel *TransformOp(LogicalOperator &dop);
substrait::Rel *TransformFilter(LogicalOperator &dop);
substrait::Rel *TransformProjection(LogicalOperator &dop);
substrait::Rel *TransformTopN(LogicalOperator &dop);
substrait::Rel *TransformLimit(LogicalOperator &dop);
substrait::Rel *TransformOrderBy(LogicalOperator &dop);
substrait::Rel *TransformComparisonJoin(LogicalOperator &dop);
substrait::Rel *TransformAggregateGroup(LogicalOperator &dop);
substrait::Rel *TransformGet(LogicalOperator &dop);
substrait::Rel *TransformCrossProduct(LogicalOperator &dop);
substrait::Rel *TransformUnion(LogicalOperator &dop);
substrait::Rel *TransformDistinct(LogicalOperator &dop);
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);

substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget);
void TransformParquetScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget, BindInfo &bind_info,
FunctionData &bind_data);

//! Methods to transform DuckDBConstants to Substrait Expressions
void TransformConstant(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformInteger(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformConstant(Value &dval, substrait::Expression &sexpr);
void TransformInteger(Value &dval, substrait::Expression &sexpr);
void TransformDouble(Value &dval, substrait::Expression &sexpr);
void TransformBigInt(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformDate(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformVarchar(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformBoolean(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformDecimal(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformBigInt(Value &dval, substrait::Expression &sexpr);
void TransformDate(Value &dval, substrait::Expression &sexpr);
void TransformVarchar(Value &dval, substrait::Expression &sexpr);
void TransformBoolean(Value &dval, substrait::Expression &sexpr);
void TransformDecimal(Value &dval, substrait::Expression &sexpr);
void TransformHugeInt(Value &dval, substrait::Expression &sexpr);
void TransformSmallInt(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformSmallInt(Value &dval, substrait::Expression &sexpr);
void TransformFloat(Value &dval, substrait::Expression &sexpr);
void TransformTime(Value &dval, substrait::Expression &sexpr);
void TransformInterval(Value &dval, substrait::Expression &sexpr);
void TransformTimestamp(Value &dval, substrait::Expression &sexpr);
void TransformEnum(duckdb::Value &dval, substrait::Expression &sexpr);
void TransformEnum(Value &dval, substrait::Expression &sexpr);

//! Methods to transform a DuckDB Expression to a Substrait Expression
void TransformExpr(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset = 0);
void TransformBoundRefExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformCastExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformFunctionExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformConstantExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr);
void TransformComparisonExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr);
void TransformConjunctionExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformNotNullExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformIsNullExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformNotExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformCaseExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr);
void TransformInExpression(duckdb::Expression &dexpr, substrait::Expression &sexpr);
void TransformExpr(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset = 0);
void TransformBoundRefExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformCastExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformFunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformConstantExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformComparisonExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformConjunctionExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformNotNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformIsNullExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformNotExpression(Expression &dexpr, substrait::Expression &sexpr, uint64_t col_offset);
void TransformCaseExpression(Expression &dexpr, substrait::Expression &sexpr);
void TransformInExpression(Expression &dexpr, substrait::Expression &sexpr);

//! Transforms a DuckDB Logical Type into a Substrait Type
::substrait::Type DuckToSubstraitType(const LogicalType &type, BaseStatistics *column_statistics = nullptr,
bool not_null = false);

//! Methods to transform DuckDB Filters to Substrait Expression
substrait::Expression *TransformFilter(uint64_t col_idx, LogicalType &column_type, duckdb::TableFilter &dfilter,
substrait::Expression *TransformFilter(uint64_t col_idx, LogicalType &column_type, TableFilter &dfilter,
LogicalType &return_type);
substrait::Expression *TransformIsNotNullFilter(uint64_t col_idx, LogicalType &column_type,
duckdb::TableFilter &dfilter, LogicalType &return_type);
substrait::Expression *TransformIsNotNullFilter(uint64_t col_idx, LogicalType &column_type, TableFilter &dfilter,
LogicalType &return_type);
substrait::Expression *TransformConjuctionAndFilter(uint64_t col_idx, LogicalType &column_type,
duckdb::TableFilter &dfilter, LogicalType &return_type);
TableFilter &dfilter, LogicalType &return_type);
substrait::Expression *TransformConstantComparisonFilter(uint64_t col_idx, LogicalType &column_type,
duckdb::TableFilter &dfilter, LogicalType &return_type);
TableFilter &dfilter, LogicalType &return_type);

//! Transforms DuckDB Join Conditions to Substrait Expression
substrait::Expression *TransformJoinCond(duckdb::JoinCondition &dcond, uint64_t left_ncol);
substrait::Expression *TransformJoinCond(JoinCondition &dcond, uint64_t left_ncol);
//! Transforms DuckDB Sort Order to Substrait Sort Order
void TransformOrder(duckdb::BoundOrderByNode &dordf, substrait::SortField &sordf);
void TransformOrder(BoundOrderByNode &dordf, substrait::SortField &sordf);

void AllocateFunctionArgument(substrait::Expression_ScalarFunction *scalar_fun, substrait::Expression *value);
static std::string &RemapFunctionName(std::string &function_name);
Expand Down Expand Up @@ -143,8 +143,8 @@ class DuckDBToSubstrait {
}

//! Variables used to register functions
std::unordered_map<std::string, uint64_t> functions_map;
std::unordered_map<std::string, uint64_t> extension_uri_map;
unordered_map<string, uint64_t> functions_map;
unordered_map<string, uint64_t> extension_uri_map;

//! Remapped DuckDB functions names to Substrait compatible function names
static const unordered_map<std::string, std::string> function_names_remap;
Expand Down
2 changes: 1 addition & 1 deletion src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ static void VerifyBlobRoundtrip(unique_ptr<LogicalOperator> &query_plan, Connect
const string &serialized);

static void SetOptions(ToSubstraitFunctionData &function, const ClientConfig &config,
const duckdb::named_parameter_map_t &named_params) {
const named_parameter_map_t &named_params) {
bool optimizer_option_set = false;
for (const auto &param : named_params) {
auto loption = StringUtil::Lower(param.first);
Expand Down
Loading

0 comments on commit 69af93d

Please sign in to comment.