diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp
index f995233..f3369a5 100644
--- a/src/from_substrait.cpp
+++ b/src/from_substrait.cpp
@@ -725,6 +725,20 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
 	return make_shared_ptr(std::move(lhs), std::move(rhs), type);
 }
 
+//! Maps a Substrait WriteRel create mode to DuckDB's OnCreateConflict; unmapped modes throw NotImplementedException
+OnCreateConflict SubstraitToDuckDB::TransformCreateMode(substrait::WriteRel_CreateMode mode) {
+	switch (mode) {
+	case substrait::WriteRel::CREATE_MODE_ERROR_IF_EXISTS:
+		return OnCreateConflict::ERROR_ON_CONFLICT;
+	case substrait::WriteRel::CREATE_MODE_IGNORE_IF_EXISTS:
+		return OnCreateConflict::IGNORE_ON_CONFLICT;
+	case substrait::WriteRel::CREATE_MODE_REPLACE_IF_EXISTS:
+		return OnCreateConflict::REPLACE_ON_CONFLICT;
+	default:
+		throw NotImplementedException("Unsupported on conflict type " + to_string(mode));
+	}
+}
+
 shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
 	auto &swrite = sop.write();
 	auto &nobj = swrite.named_table();
@@ -738,9 +752,14 @@ shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
 		schema_name = nobj.names(0);
 	}
 	auto input = TransformOp(swrite.input());
+	// A zero create_mode keeps the ERROR_ON_CONFLICT default; any other value is translated
+	auto on_conflict = OnCreateConflict::ERROR_ON_CONFLICT;
+	if (swrite.create_mode()) {
+		on_conflict = TransformCreateMode(swrite.create_mode());
+	}
 	switch (swrite.op()) {
 	case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
-		return input->CreateRel(schema_name, table_name);
+		return input->CreateRel(schema_name, table_name, false, on_conflict);
 	case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
 		return input->InsertRel(schema_name, table_name);
 	case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
@@ -841,7 +860,7 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
 		case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
 			const auto create_table = static_cast(child.get());
 			auto proj = make_shared_ptr(create_table->child, std::move(expressions), aliases);
-			return proj->CreateRel(create_table->schema_name, create_table->table_name);
+			return proj->CreateRel(create_table->schema_name, create_table->table_name, create_table->temporary, create_table->on_conflict);
 		}
 		default:
 			return child;
diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp
index 7fbcfa4..2c8bd6f 100644
--- a/src/include/from_substrait.hpp
+++ b/src/include/from_substrait.hpp
@@ -80,6 +80,7 @@ class SubstraitToDuckDB {
 	shared_ptr TransformSetOp(const substrait::Rel &sop,
 	                          const google::protobuf::RepeatedPtrField *names = nullptr);
 	shared_ptr TransformWriteOp(const substrait::Rel &sop);
+	static OnCreateConflict TransformCreateMode(substrait::WriteRel_CreateMode mode);
 
 	//! Transform Substrait Expressions to DuckDB Expressions
 	unique_ptr TransformExpr(const substrait::Expression &sexpr,
diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp
index 7466395..113ba18 100644
--- a/src/include/to_substrait.hpp
+++ b/src/include/to_substrait.hpp
@@ -72,6 +72,7 @@ class DuckDBToSubstrait {
 	substrait::Rel *TransformCreateTable(LogicalOperator &dop);
 	substrait::Rel *TransformInsertTable(LogicalOperator &dop);
 	substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
+	static substrait::WriteRel_CreateMode TransformOnCreateConflict(OnCreateConflict on_conflict);
 	static substrait::Rel *TransformDummyScan();
 	//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
 	//! To Substrait;
diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp
index 6ff0b0a..b9fa0d6 100644
--- a/src/to_substrait.cpp
+++ b/src/to_substrait.cpp
@@ -1437,6 +1437,20 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) {
 	return rel;
 }
 
+//! Maps DuckDB's OnCreateConflict to a Substrait WriteRel create mode; unmapped values throw NotImplementedException
+substrait::WriteRel_CreateMode DuckDBToSubstrait::TransformOnCreateConflict(OnCreateConflict on_conflict) {
+	switch (on_conflict) {
+	case OnCreateConflict::ERROR_ON_CONFLICT:
+		return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_ERROR_IF_EXISTS;
+	case OnCreateConflict::IGNORE_ON_CONFLICT:
+		return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_IGNORE_IF_EXISTS;
+	case OnCreateConflict::REPLACE_ON_CONFLICT:
+		return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_REPLACE_IF_EXISTS;
+	default:
+		throw NotImplementedException("Unknown OnCreateConflict type " + to_string(static_cast<int>(on_conflict)));
+	}
+}
+
 substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
 	auto rel = new substrait::Rel();
 	auto &create_table = dop.Cast();
@@ -1468,7 +1482,7 @@ substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
 	auto named_table = write->mutable_named_table();
 	named_table->add_names(create_info.schema);
 	named_table->add_names(create_info.table);
-
+	write->set_create_mode(TransformOnCreateConflict(create_info.on_conflict));
 	return rel;
 }
 
diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp
index 27ab432..3794704 100644
--- a/test/c/test_substrait_c_api.cpp
+++ b/test/c/test_substrait_c_api.cpp
@@ -320,3 +320,43 @@ TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
 	REQUIRE(CHECK_COLUMN(result, 0, {2, 6}));
 	REQUIRE(CHECK_COLUMN(result, 1, {4, 8}));
 }
+
+TEST_CASE("Test C CTAS with create_on_conflict via Substrait API", "[substrait-api]") {
+	DuckDB db(nullptr);
+	Connection con(db);
+
+	CreateEmployeeTable(con);
+
+	// First CTAS on the fresh database: assert it succeeds instead of discarding the result
+	REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
+	                                             "SELECT employee_id, salary FROM employees"));
+
+	auto result = con.Query("SELECT * from employee_salaries");
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
+
+	// IF NOT EXISTS on an existing table: the statement succeeds and the table contents stay as they were
+	REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE TABLE IF NOT EXISTS employee_salaries AS "
+	                                             "SELECT employee_id, department_id, salary FROM employees"));
+
+	result = con.Query("SELECT * from employee_salaries");
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
+
+	// Plain CTAS on an existing table: the statement fails and the table contents stay as they were
+	auto res3 = ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
+	                                         "SELECT employee_id, department_id, salary FROM employees");
+	REQUIRE_FAIL(res3);
+	result = con.Query("SELECT * from employee_salaries");
+	REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
+
+	// OR REPLACE on an existing table: the table is rebuilt from the new SELECT
+	REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE OR REPLACE TABLE employee_salaries AS "
+	                                             "SELECT name, salary FROM employees"));
+
+	result = con.Query("SELECT * from employee_salaries");
+	REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
+	REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
+}
+