From ad9820378bff891fc93796e54234696c0ce133c9 Mon Sep 17 00:00:00 2001 From: eitsupi <50911393+eitsupi@users.noreply.github.com> Date: Fri, 16 Sep 2022 10:42:34 +0900 Subject: [PATCH] ARROW-17689: [R] Implement dplyr::across() inside group_by() (#14122) Because the handling of the case `.add = TRUE` and the `add` argument have been changed, test cases for these are also added. Authored-by: SHIMA Tatsuya Signed-off-by: Dewey Dunnington --- r/R/dplyr-group-by.R | 38 ++++----- r/tests/testthat/test-dplyr-group-by.R | 110 +++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 22 deletions(-) diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R index c650799e8d06c..57cf417c9ad2b 100644 --- a/r/R/dplyr-group-by.R +++ b/r/R/dplyr-group-by.R @@ -21,37 +21,31 @@ group_by.arrow_dplyr_query <- function(.data, ..., .add = FALSE, - add = .add, + add = NULL, .drop = dplyr::group_by_drop_default(.data)) { + if (!missing(add)) { + .Deprecated( + msg = paste("The `add` argument of `group_by()` is deprecated. Please use the `.add` argument instead.") + ) + .add <- add + } + .data <- as_adq(.data) - new_groups <- enquos(...) - # ... can contain expressions (i.e. can add (or rename?) columns) and so we - # need to identify those and add them on to the query with mutate. Specifically, - # we want to mark as new: - # * expressions (named or otherwise) - # * variables that have new names - # All others (i.e. simple references to variables) should not be (re)-added + expression_list <- expand_across(.data, quos(...)) + new_groups <- ensure_named_exprs(expression_list) - # Identify any groups with names which aren't in names of .data - new_group_ind <- map_lgl(new_groups, ~ !(quo_name(.x) %in% names(.data))) - # Identify any groups which don't have names - named_group_ind <- map_lgl(names(new_groups), nzchar) - # Retain any new groups identified above - new_groups <- new_groups[new_group_ind | named_group_ind] if (length(new_groups)) { - # now either use the name that was given in ... or if that is "" then use the expr - names(new_groups) <- imap_chr(new_groups, ~ ifelse(.y == "", quo_name(.x), .y)) - # Add them to the data .data <- dplyr::mutate(.data, !!!new_groups) } - if (".add" %in% names(formals(dplyr::group_by))) { - # For compatibility with dplyr >= 1.0 - gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names + + if (.add) { + gv <- union(dplyr::group_vars(.data), names(new_groups)) } else { - gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names + gv <- names(new_groups) } - .data$group_by_vars <- gv + + .data$group_by_vars <- gv %||% character() .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data)) .data } diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index c7380e96ec302..9bb6aa9600dbd 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -166,3 +166,113 @@ test_that("group_by() with namespaced functions", { tbl ) }) + +test_that("group_by() with .add", { + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(.add = FALSE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(.add = TRUE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(chr, .add = FALSE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(chr, .add = TRUE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(chr, .add = FALSE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) + compare_dplyr_binding( + .input %>% + group_by(chr, .add = TRUE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) + suppressWarnings(compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(add = FALSE) %>% + collect(), + tbl, + warning = "deprecated" + )) + suppressWarnings(compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(add = TRUE) %>% + collect(), + tbl, + warning = "deprecated" + )) + expect_warning( + tbl %>% + arrow_table() %>% + group_by(add = TRUE) %>% + collect(), + "The `add` argument of `group_by\\(\\)` is deprecated" + ) + expect_error( + suppressWarnings( + tbl %>% + arrow_table() %>% + group_by(add = dbl2) %>% + collect() + ), + "object 'dbl2' not found" + ) +}) + +test_that("Can use across() within group_by()", { + test_groups <- c("dbl", "int", "chr") + compare_dplyr_binding( + .input %>% + group_by(across()) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(across(starts_with("d"))) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(across({{ test_groups }})) %>% + collect(), + tbl + ) + + # ARROW-12778 - `where()` is not yet supported + expect_error( + compare_dplyr_binding( + .input %>% + group_by(across(where(is.numeric))) %>% + collect(), + tbl + ), + "Unsupported selection helper" + ) +})