From 07961a0dc95f7358d7056aafb153274ccd0ac7a2 Mon Sep 17 00:00:00 2001 From: "Simon P. Couch" Date: Wed, 17 Jan 2024 08:32:22 -0600 Subject: [PATCH] fix model fit for spark tbls (#1047) --- NEWS.md | 2 ++ R/arguments.R | 10 +++++++++- tests/testthat/test_boost_tree.R | 7 +++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 00bdbfc98..53bd159bf 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* Fixed bug in fitting some model types with the `"spark"` engine (#1045). + * Fixed issue in `mlp()` metadata where the `stop_iter` engine argument had been mistakenly protected for the `"brulee"` engine. (#1050) * `.filter_eval_time()` was moved to the survival standalone file. diff --git a/R/arguments.R b/R/arguments.R index 2ed1de0af..42721aac4 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -328,7 +328,11 @@ min_cols <- function(num_cols, source) { #' @export #' @rdname min_cols min_rows <- function(num_rows, source, offset = 0) { - n <- nrow(source) + if (inherits(source, "tbl_spark")) { + n <- nrow_spark(source) + } else { + n <- nrow(source) + } if (num_rows > n - offset) { msg <- paste0(num_rows, " samples were requested but there were ", n, @@ -340,3 +344,7 @@ min_rows <- function(num_rows, source, offset = 0) { as.integer(num_rows) } +nrow_spark <- function(source) { + rlang::check_installed("sparklyr") + sparklyr::sdf_nrow(source) +} diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 7abfcf9a4..f92216870 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -28,6 +28,9 @@ test_that('bad input', { ## ----------------------------------------------------------------------------- test_that('argument checks for data dimensions', { + skip_if_not_installed("sparklyr") + library(sparklyr) + skip_if(nrow(spark_installed_versions()) == 0) spec <- boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>% @@ -36,6 +39,10 @@ 
test_that('argument checks for data dimensions', { args <- translate(spec)$method$fit$args expect_equal(args$min_instances_per_node, expr(min_rows(1000, x))) + + sc <- spark_connect(master = "local") + cars <- copy_to(sc, mtcars, overwrite = TRUE) + expect_equal(min_rows(10, cars), 10) }) test_that('boost_tree can be fit with 1 predictor if validation is used', {