Skip to content

Commit

Permalink
Fix slice_sample() handling of n= and prop= (#6172)
Browse files Browse the repository at this point in the history
* Fix slice_sample() handling of n= and prop=
  • Loading branch information
romainfrancois authored Feb 3, 2022
1 parent 520b36e commit b93d0c3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
23 changes: 12 additions & 11 deletions R/slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ slice_sample <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE)
#' @export
slice_sample.data.frame <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) {
check_slice_dots(..., n = n, prop = prop)
size <- get_slice_size(n = n, prop = prop)
size <- get_slice_size(n = n, prop = prop, allow_negative = FALSE)

dplyr_local_error_call()
slice(.data, local({
Expand Down Expand Up @@ -412,28 +412,29 @@ check_slice_n_prop <- function(n, prop, error_call = caller_env()) {
}
}

get_slice_size <- function(n, prop, error_call = caller_env()) {
get_slice_size <- function(n, prop, allow_negative = TRUE, error_call = caller_env()) {
slice_input <- check_slice_n_prop(n, prop, error_call = error_call)

if (slice_input$type == "n") {
if (slice_input$n < 0) {
function(n) max(ceiling(n + slice_input$n), 0)
if (slice_input$n > 0) {
function(n) floor(slice_input$n)
} else if (allow_negative) {
function(n) ceiling(n + slice_input$n)
} else {
function(n) min(floor(slice_input$n), n)
abort("`n` must be positive.", call = error_call)
}
} else if (slice_input$type == "prop") {
if (slice_input$prop < 0) {
function(n) max(ceiling(n + slice_input$prop * n), 0)
if (slice_input$prop > 0) {
function(n) floor(slice_input$prop * n)
} else if (allow_negative) {
function(n) ceiling(n + slice_input$prop * n)
} else {
function(n) min(floor(slice_input$prop * n), n)
abort("`prop` must be positive.", call = error_call)
}
}
}

sample_int <- function(n, size, replace = FALSE, wt = NULL) {
if (!replace) {
size <- min(size, n)
}
if (size == 0L) {
integer(0)
} else {
Expand Down
31 changes: 31 additions & 0 deletions tests/testthat/_snaps/slice.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,34 @@
# slice_sample() handles n= and prop=

Code
(expect_error(df %>% slice_sample(n = -1)))
Output
<error/rlang_error>
Error in `slice_sample()`:
! `n` must be positive.
Code
(expect_error(df %>% slice_sample(prop = -1)))
Output
<error/rlang_error>
Error in `slice_sample()`:
! `prop` must be positive.
Code
(expect_error(df %>% slice_sample(n = 4, replace = FALSE)))
Output
<error/rlang_error>
Error in `slice_sample()`:
! Problem while computing indices.
Caused by error in `sample.int()`:
! cannot take a sample larger than the population when 'replace = FALSE'
Code
(expect_error(df %>% slice_sample(prop = 4, replace = FALSE)))
Output
<error/rlang_error>
Error in `slice_sample()`:
! Problem while computing indices.
Caused by error in `sample.int()`:
! cannot take a sample larger than the population when 'replace = FALSE'

# slice() gives meaningfull errors

Code
Expand Down
33 changes: 31 additions & 2 deletions tests/testthat/test-slice.r
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,46 @@ test_that("slice() handles matrix and data frame columns (#3630)", {

# Slice variants ----------------------------------------------------------

# Checks how slice_sample() treats n= and prop= on a one-row data frame:
# requesting more rows than exist succeeds when replace = TRUE, while negative
# sizes and oversampling with replace = FALSE are expected to raise errors.
test_that("slice_sample() handles n= and prop=", {
df <- data.frame(a = 1)

# With replacement, n = 4 on a single row yields that row four times.
expect_equal(
df %>% slice_sample(n = 4, replace = TRUE),
df %>% slice(rep(1, 4))
)

# prop = 4 of one row is also four rows when sampling with replacement.
expect_equal(
df %>% slice_sample(prop = 4, replace = TRUE),
df %>% slice(rep(1, 4))
)

# Error cases; the resulting conditions are recorded in the snapshot file.
# Each expect_error() is wrapped in parentheses so that its (otherwise
# invisible) return value is printed and captured by expect_snapshot().
expect_snapshot({
(expect_error(
df %>% slice_sample(n = -1)
))
(expect_error(
df %>% slice_sample(prop = -1)
))

(expect_error(
df %>% slice_sample(n = 4, replace = FALSE)
))

(expect_error(
df %>% slice_sample(prop = 4, replace = FALSE)
))
})
})

# On a five-row data frame, asks each slice variant for more rows than exist
# (n = 6, expecting all 5 rows back) and for a negative n whose magnitude
# exceeds the row count (n = -6, expecting 0 rows back).
# NOTE(review): this hunk is shown without +/- markers, so the slice_sample()
# lines may be the ones this commit removes -- confirm against the real file.
test_that("functions silently truncate results", {
df <- data.frame(x = 1:5)

expect_equal(df %>% slice_head(n = 6) %>% nrow(), 5)
expect_equal(df %>% slice_tail(n = 6) %>% nrow(), 5)
expect_equal(df %>% slice_sample(n = 6) %>% nrow(), 5)
expect_equal(df %>% slice_min(x, n = 6) %>% nrow(), 5)
expect_equal(df %>% slice_max(x, n = 6) %>% nrow(), 5)
expect_equal(df %>% slice_head(n = -6) %>% nrow(), 0)
expect_equal(df %>% slice_tail(n = -6) %>% nrow(), 0)
expect_equal(df %>% slice_sample(n = -6) %>% nrow(), 0)
expect_equal(df %>% slice_min(x, n = -6) %>% nrow(), 0)
expect_equal(df %>% slice_max(x, n = -6) %>% nrow(), 0)
})
Expand Down

0 comments on commit b93d0c3

Please sign in to comment.