Skip to content

Commit

Permalink
feat: Support create_on_conflict in CTAS
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran committed Dec 9, 2024
1 parent abc4b70 commit 05a0d49
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 3 deletions.
21 changes: 19 additions & 2 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

OnCreateConflict SubstraitToDuckDB::TransformCreateMode(substrait::WriteRel_CreateMode mode) {
switch (mode) {
case substrait::WriteRel::CREATE_MODE_ERROR_IF_EXISTS:
return OnCreateConflict::ERROR_ON_CONFLICT;
case substrait::WriteRel::CREATE_MODE_IGNORE_IF_EXISTS:
return OnCreateConflict::IGNORE_ON_CONFLICT;
case substrait::WriteRel::CREATE_MODE_REPLACE_IF_EXISTS:
return OnCreateConflict::REPLACE_ON_CONFLICT;
default:
throw NotImplementedException("Unsupported on conflict type " + to_string(mode));
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();
auto &nobj = swrite.named_table();
Expand All @@ -738,9 +751,13 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
schema_name = nobj.names(0);
}
auto input = TransformOp(swrite.input());
auto on_conflict = OnCreateConflict::ERROR_ON_CONFLICT;
if (swrite.create_mode()) {
on_conflict = TransformCreateMode(swrite.create_mode());
}
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
return input->CreateRel(schema_name, table_name, false, on_conflict);

Check failure on line 760 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

too many arguments to function call, expected at most 3, have 4

Check failure on line 760 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

too many arguments to function call, expected at most 3, have 4
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
Expand Down Expand Up @@ -841,7 +858,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
const auto create_table = static_cast<CreateTableRelation *>(child.get());
auto proj = make_shared_ptr<ProjectionRelation>(create_table->child, std::move(expressions), aliases);
return proj->CreateRel(create_table->schema_name, create_table->table_name);
return proj->CreateRel(create_table->schema_name, create_table->table_name, create_table->temporary, create_table->on_conflict);

Check failure on line 861 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

no member named 'on_conflict' in 'duckdb::CreateTableRelation'

Check failure on line 861 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

no member named 'on_conflict' in 'duckdb::CreateTableRelation'
}
default:
return child;
Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SubstraitToDuckDB {
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);
static OnCreateConflict TransformCreateMode(substrait::WriteRel_CreateMode mode);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::WriteRel_CreateMode TransformOnCreateConflict(OnCreateConflict on_conflict);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
15 changes: 14 additions & 1 deletion src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,19 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) {
return rel;
}

substrait::WriteRel_CreateMode DuckDBToSubstrait::TransformOnCreateConflict(OnCreateConflict on_conflict) {
switch(on_conflict) {
case OnCreateConflict::ERROR_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_ERROR_IF_EXISTS;
case OnCreateConflict::IGNORE_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_IGNORE_IF_EXISTS;
case OnCreateConflict::REPLACE_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_REPLACE_IF_EXISTS;
default:
throw NotImplementedException("Unknown OnCreateConflict type " + to_string((int)on_conflict));
}
}

substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &create_table = dop.Cast<LogicalCreateTable>();
Expand Down Expand Up @@ -1468,7 +1481,7 @@ substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
auto named_table = write->mutable_named_table();
named_table->add_names(create_info.schema);
named_table->add_names(create_info.table);

write->set_create_mode(TransformOnCreateConflict(create_info.on_conflict));
return rel;
}

Expand Down
38 changes: 38 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,41 @@ TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 0, {2, 6}));
REQUIRE(CHECK_COLUMN(result, 1, {4, 8}));
}

TEST_CASE("Test C CTAS with create_on_conflict via Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

auto res1 = ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
"SELECT employee_id, salary FROM employees"
);

auto result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));


REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE TABLE IF NOT EXISTS employee_salaries AS "
"SELECT employee_id, department_id, salary FROM employees"));

result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));

auto res3 = ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
"SELECT employee_id, department_id, salary FROM employees");
REQUIRE_FAIL(res3);
result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));

REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE OR REPLACE TABLE employee_salaries AS "
"SELECT name, salary FROM employees"));

result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
}

0 comments on commit 05a0d49

Please sign in to comment.