From 16127f81b6f05fcad810066e7c8f6433f4f47a6f Mon Sep 17 00:00:00 2001 From: Bruno Tremblay Date: Thu, 9 Jan 2025 19:49:47 -0500 Subject: [PATCH] rebase --- R/progress_display.R | 31 ++++++++++++++++ src/connection.cpp | 50 ++++++++++++++++++++++++++ src/include/r_progress_bar_display.hpp | 28 +++++++++++++++ src/include/rapi.hpp | 1 + src/utils.cpp | 1 + tests/testthat/test-progress-display.R | 11 ++++++ 6 files changed, 122 insertions(+) create mode 100644 R/progress_display.R create mode 100644 src/include/r_progress_bar_display.hpp create mode 100644 tests/testthat/test-progress-display.R diff --git a/R/progress_display.R b/R/progress_display.R new file mode 100644 index 000000000..d14c2b7ff --- /dev/null +++ b/R/progress_display.R @@ -0,0 +1,31 @@ +duckdb_progress_display <- function(x) { + if (x <= 100) cat(sprintf("\r%3d%%", x)) + if (x >= 100) cat("\r ") +} + +check_progress_display <- function(f) { + if (is.null(f)) return() + if (is.TRUE(f) || isFALSE(f)) return() + if (is.function(f)) { + if (length(formals(f)) > 0) return() + stop("`progress_display` has no argument, expecting at least one.") + } + stop("`progress_display` is not function, expecting either a boolean or function.") +} + +set_progress_display <- function(f) { + check_progress_display(f) + options("duckdb.progress_display" = { + if (isTRUE(progress_display)) { + duckdb_progress_display + } else if (is.function(progress_display)) { + progress_display + } else { + NULL + } + }) +} + +get_progress_display <- function(f) { + getOption("duckdb.progress_display", default = duckdb_progress_display) +} diff --git a/src/connection.cpp b/src/connection.cpp index 184a5c2fd..71922ca89 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -1,4 +1,6 @@ #include "rapi.hpp" +#include "r_progress_bar_display.hpp" +#include "duckdb/main/client_context.hpp" using namespace duckdb; @@ -7,6 +9,50 @@ void duckdb::ConnDeleter(ConnWrapper *conn) { delete conn; } +unique_ptr RProgressBarDisplay::Create() { + return make_uniq(); +} + +void RProgressBarDisplay::Initialize() { + auto progress_display = Rf_GetOption(RStrings::get().progress_display_sym, R_BaseEnv); + if (Rf_isFunction(progress_display)) { + progress_callback = progress_display; + } + D_ASSERT(progress_callback != R_NilValue); +} + +RProgressBarDisplay::RProgressBarDisplay() : ProgressBarDisplay() { + // Empty +} + +void RProgressBarDisplay::Update(double percentage) { + if (progress_callback == R_NilValue) { + Initialize(); + } + if (progress_callback != R_NilValue) { + try { + cpp11::sexp call = Rf_lang2(progress_callback, Rf_ScalarReal(percentage)); + cpp11::safe[Rf_eval](call, R_BaseEnv); + } catch (std::exception &e) { + // Ignore progress bar error + } + } +} + +void RProgressBarDisplay::Finish() { + Update(100); +} + +static void SetDefaultConfigArguments(ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + // Set the function used to create the display for the progress bar + config.display_create_func = RProgressBarDisplay::Create; + auto progress_display = Rf_GetOption(RStrings::get().progress_display_sym, R_BaseEnv); + if (Rf_isFunction(progress_display)) { + config.enable_progress_bar = true; + } +} + [[cpp11::register]] duckdb::conn_eptr_t rapi_connect(duckdb::db_eptr_t dual) { if (!dual || !dual.get()) { cpp11::stop("rapi_connect: Invalid database reference"); @@ -20,6 +66,10 @@ void duckdb::ConnDeleter(ConnWrapper *conn) { conn_wrapper->conn = make_uniq(*db->db); conn_wrapper->db.swap(db); + // Set progress display config + auto &client_context = *conn_wrapper->conn->context; + SetDefaultConfigArguments(client_context); + // The connection now holds a reference to the database. // This reference is released when the connection is closed. // From the R side, the database pointer will remain valid diff --git a/src/include/r_progress_bar_display.hpp b/src/include/r_progress_bar_display.hpp new file mode 100644 index 000000000..6fd867daa --- /dev/null +++ b/src/include/r_progress_bar_display.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "rapi.hpp" +#include "duckdb/common/progress_bar/progress_bar_display.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +class RProgressBarDisplay : public ProgressBarDisplay { +public: + RProgressBarDisplay(); + virtual ~RProgressBarDisplay() { + } + + static unique_ptr Create(); + +public: + void Update(double percentage) override; + void Finish() override; + +private: + void Initialize(); + +private: + SEXP progress_callback = R_NilValue; +}; + +} // namespace duckdb diff --git a/src/include/rapi.hpp b/src/include/rapi.hpp index 50fe9a86b..213990a7c 100644 --- a/src/include/rapi.hpp +++ b/src/include/rapi.hpp @@ -171,6 +171,7 @@ struct RStrings { SEXP ImportRecordBatchReader_sym; SEXP materialize_callback_sym; SEXP materialize_message_sym; + SEXP progress_display_sym; SEXP duckdb_row_names_sym; SEXP duckdb_vector_sym; diff --git a/src/utils.cpp b/src/utils.cpp index 746386566..1f0e3aabf 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -82,6 +82,7 @@ RStrings::RStrings() { Table__from_record_batches_sym = Rf_install("Table__from_record_batches"); materialize_message_sym = Rf_install("duckdb.materialize_message"); materialize_callback_sym = Rf_install("duckdb.materialize_callback"); + progress_display_sym = Rf_install("duckdb.progress_display"); duckdb_row_names_sym = Rf_install("duckdb_row_names"); duckdb_vector_sym = Rf_install("duckdb_vector"); } diff --git a/tests/testthat/test-progress-display.R b/tests/testthat/test-progress-display.R new file mode 100644 index 000000000..7e6858b78 --- /dev/null +++ b/tests/testthat/test-progress-display.R @@ -0,0 +1,11 @@ +test_that("progress display", { + + expect_error(check_progress_display(5), "expecting either a boolean or function") + expect_error(check_progress_display(function(){}), "has no argument, expecting at least one") + + expect_null(check_progress_display(function(x){})) + expect_null(check_progress_display(TRUE)) + expect_null(check_progress_display(FALSE)) + expect_null(check_progress_display(NULL)) + +})