From b93d0c3e048ea2305edc4be0f3373d5b71d87604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20Fran=C3=A7ois?= Date: Thu, 3 Feb 2022 09:51:32 +0100 Subject: [PATCH] Fix `slice_sample()` handling of `n=` and `prop=` (#6172) * Fix slice_sample() handling of n= and prop= --- R/slice.R | 23 ++++++++++++----------- tests/testthat/_snaps/slice.md | 31 +++++++++++++++++++++++++++++++ tests/testthat/test-slice.r | 33 +++++++++++++++++++++++++++++++-- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/R/slice.R b/R/slice.R index fdb36ce7c8..374aa08614 100644 --- a/R/slice.R +++ b/R/slice.R @@ -233,7 +233,7 @@ slice_sample <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) #' @export slice_sample.data.frame <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) { check_slice_dots(..., n = n, prop = prop) - size <- get_slice_size(n = n, prop = prop) + size <- get_slice_size(n = n, prop = prop, allow_negative = FALSE) dplyr_local_error_call() slice(.data, local({ @@ -412,28 +412,29 @@ check_slice_n_prop <- function(n, prop, error_call = caller_env()) { } } -get_slice_size <- function(n, prop, error_call = caller_env()) { +get_slice_size <- function(n, prop, allow_negative = TRUE, error_call = caller_env()) { slice_input <- check_slice_n_prop(n, prop, error_call = error_call) if (slice_input$type == "n") { - if (slice_input$n < 0) { - function(n) max(ceiling(n + slice_input$n), 0) + if (slice_input$n > 0) { + function(n) floor(slice_input$n) + } else if (allow_negative) { + function(n) ceiling(n + slice_input$n) } else { - function(n) min(floor(slice_input$n), n) + abort("`n` must be positive.", call = error_call) } } else if (slice_input$type == "prop") { - if (slice_input$prop < 0) { - function(n) max(ceiling(n + slice_input$prop * n), 0) + if (slice_input$prop > 0) { + function(n) floor(slice_input$prop * n) + } else if (allow_negative) { + function(n) ceiling(n + slice_input$prop * n) } else { - function(n) min(floor(slice_input$prop * n), n) + abort("`prop` must be positive.", call = error_call) } } } sample_int <- function(n, size, replace = FALSE, wt = NULL) { - if (!replace) { - size <- min(size, n) - } if (size == 0L) { integer(0) } else { diff --git a/tests/testthat/_snaps/slice.md b/tests/testthat/_snaps/slice.md index c3c36ccb1a..0db915610b 100644 --- a/tests/testthat/_snaps/slice.md +++ b/tests/testthat/_snaps/slice.md @@ -1,3 +1,34 @@ +# slice_sample() handles n= and prop= + + Code + (expect_error(df %>% slice_sample(n = -1))) + Output + + Error in `slice_sample()`: + ! `n` must be positive. + Code + (expect_error(df %>% slice_sample(prop = -1))) + Output + + Error in `slice_sample()`: + ! `prop` must be positive. + Code + (expect_error(df %>% slice_sample(n = 4, replace = FALSE))) + Output + + Error in `slice_sample()`: + ! Problem while computing indices. + Caused by error in `sample.int()`: + ! cannot take a sample larger than the population when 'replace = FALSE' + Code + (expect_error(df %>% slice_sample(prop = 4, replace = FALSE))) + Output + + Error in `slice_sample()`: + ! Problem while computing indices. + Caused by error in `sample.int()`: + ! cannot take a sample larger than the population when 'replace = FALSE' + # slice() gives meaningfull errors Code diff --git a/tests/testthat/test-slice.r b/tests/testthat/test-slice.r index 371a229ab3..4a53dcb87a 100644 --- a/tests/testthat/test-slice.r +++ b/tests/testthat/test-slice.r @@ -189,17 +189,46 @@ test_that("slice() handles matrix and data frame columns (#3630)", { # Slice variants ---------------------------------------------------------- +test_that("slice_sample() handles n= and prop=", { + df <- data.frame(a = 1) + + expect_equal( + df %>% slice_sample(n = 4, replace = TRUE), + df %>% slice(rep(1, 4)) + ) + + expect_equal( + df %>% slice_sample(prop = 4, replace = TRUE), + df %>% slice(rep(1, 4)) + ) + + expect_snapshot({ + (expect_error( + df %>% slice_sample(n = -1) + )) + (expect_error( + df %>% slice_sample(prop = -1) + )) + + (expect_error( + df %>% slice_sample(n = 4, replace = FALSE) + )) + + (expect_error( + df %>% slice_sample(prop = 4, replace = FALSE) + )) + }) +}) + test_that("functions silently truncate results", { df <- data.frame(x = 1:5) expect_equal(df %>% slice_head(n = 6) %>% nrow(), 5) expect_equal(df %>% slice_tail(n = 6) %>% nrow(), 5) - expect_equal(df %>% slice_sample(n = 6) %>% nrow(), 5) expect_equal(df %>% slice_min(x, n = 6) %>% nrow(), 5) expect_equal(df %>% slice_max(x, n = 6) %>% nrow(), 5) expect_equal(df %>% slice_head(n = -6) %>% nrow(), 0) expect_equal(df %>% slice_tail(n = -6) %>% nrow(), 0) - expect_equal(df %>% slice_sample(n = -6) %>% nrow(), 0) expect_equal(df %>% slice_min(x, n = -6) %>% nrow(), 0) expect_equal(df %>% slice_max(x, n = -6) %>% nrow(), 0) })