Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17689: [R] Implement dplyr::across() inside group_by() #14122

Merged
merged 10 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions r/R/dplyr-group-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,31 @@
# group_by() method for arrow dplyr queries.
#
# Arguments:
#   .data  the object to group; it is passed through as_adq() before use.
#   ...    grouping columns or expressions, captured as quosures. They are
#          run through expand_across() and then ensure_named_exprs()
#          (presumably expanding any across() calls and giving every
#          expression a name -- confirm against those helpers).
#   .add   if TRUE, the new groups are combined with the grouping variables
#          already present on the query; if FALSE (the default) the new
#          groups replace them.
#   add    deprecated alias for `.add`. Supplying it emits a deprecation
#          warning and its value is then used as `.add`.
#   .drop  stored as `drop_empty_groups` when at least one grouping variable
#          results; otherwise the dplyr default is stored instead.
#
# Returns the query with `group_by_vars` and `drop_empty_groups` updated.
group_by.arrow_dplyr_query <- function(.data,
                                       ...,
                                       .add = FALSE,
                                       add = NULL,
                                       .drop = dplyr::group_by_drop_default(.data)) {
  # `add` is only honoured when the caller explicitly supplied it.
  if (!missing(add)) {
    .Deprecated(
      msg = paste(
        "The `add` argument of `group_by()` is deprecated.",
        "Please use the `.add` argument instead."
      )
    )
    .add <- add
  }

  .data <- as_adq(.data)
  expression_list <- expand_across(.data, quos(...))
  new_groups <- ensure_named_exprs(expression_list)

  if (length(new_groups)) {
    # Add them to the data
    .data <- dplyr::mutate(.data, !!!new_groups)
  }

  if (.add) {
    # Keep the existing grouping variables and append the new ones.
    gv <- union(dplyr::group_vars(.data), names(new_groups))
  } else {
    gv <- names(new_groups)
  }

  # names() can be NULL (e.g. for an empty list), so fall back to an empty
  # character vector.
  .data$group_by_vars <- gv %||% character()
  .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data))
  .data
}
Expand Down
110 changes: 110 additions & 0 deletions r/tests/testthat/test-dplyr-group-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,113 @@ test_that("group_by() with namespaced functions", {
tbl
)
})

# Exercises the `.add` argument of group_by() and its deprecated alias `add`.
# Each compare_dplyr_binding() call is given a pipeline written against
# `.input` plus the data to run it on (presumably the helper evaluates the
# pipeline with both dplyr and arrow and compares results -- confirm against
# the helper's definition).
test_that("group_by() with .add", {
  # group_by() with no new groups on data grouped earlier in the pipeline
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(.add = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(.add = TRUE) %>%
      collect(),
    tbl
  )

  # group_by() with a new group on data grouped earlier in the pipeline
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(chr, .add = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(chr, .add = TRUE) %>%
      collect(),
    tbl
  )

  # Same, but the data handed to the helper is already grouped by dbl2
  compare_dplyr_binding(
    .input %>%
      group_by(chr, .add = FALSE) %>%
      collect(),
    tbl %>%
      group_by(dbl2)
  )
  compare_dplyr_binding(
    .input %>%
      group_by(chr, .add = TRUE) %>%
      collect(),
    tbl %>%
      group_by(dbl2)
  )

  # The deprecated `add` argument; the helper is told to expect a
  # "deprecated" warning
  suppressWarnings(compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(add = FALSE) %>%
      collect(),
    tbl,
    warning = "deprecated"
  ))
  suppressWarnings(compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(add = TRUE) %>%
      collect(),
    tbl,
    warning = "deprecated"
  ))

  # The arrow method's own deprecation message
  expect_warning(
    tbl %>%
      arrow_table() %>%
      group_by(add = TRUE) %>%
      collect(),
    "The `add` argument of `group_by\\(\\)` is deprecated"
  )

  # A bare column name passed as `add` is expected to fail with an
  # object-not-found error
  expect_error(
    suppressWarnings(
      tbl %>%
        arrow_table() %>%
        group_by(add = dbl2) %>%
        collect()
    ),
    "object 'dbl2' not found"
  )
})

# Verifies that across() can be used to pick grouping columns in group_by().
test_that("Can use across() within group_by()", {
  # across() with no selection
  compare_dplyr_binding(
    .input %>%
      group_by(across()) %>%
      collect(),
    tbl
  )

  # across() with a tidyselect helper
  compare_dplyr_binding(
    .input %>%
      group_by(across(starts_with("d"))) %>%
      collect(),
    tbl
  )

  # across() with column names held in a character vector
  group_cols <- c("dbl", "int", "chr")
  compare_dplyr_binding(
    .input %>%
      group_by(across({{ group_cols }})) %>%
      collect(),
    tbl
  )

  # ARROW-12778 - `where()` is not yet supported
  expect_error(
    compare_dplyr_binding(
      .input %>%
        group_by(across(where(is.numeric))) %>%
        collect(),
      tbl
    ),
    regexp = "Unsupported selection helper"
  )
})