From 1b3caf6b232b7855956d3ec45ee95ede0492e78f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 7 Dec 2024 10:04:08 -0500 Subject: [PATCH] GH-29642: [R] Support for .keep_all = TRUE with distinct() (#44652) ### Rationale for this change Support a missing feature, just wiring up some stuff from R to Acero, then adding docs and tests. This is mostly picking up where #13934 started and finishing it out. Thanks @mopcup for the initial lift. ### What changes are included in this PR? An aggregation binding, some symbol manipulation, and tests. I also cleaned up some dplyr test shims from 2022. ### Are these changes tested? Yes, though if anyone knows of odd corners in `distinct()` that aren't covered by this, we can add more ### Are there any user-facing changes? Yes indeed. * GitHub Issue: #29642 --- r/R/arrow-package.R | 5 +- r/R/dplyr-distinct.R | 25 ++++++-- r/R/dplyr-funcs-agg.R | 7 ++ r/tests/testthat/test-dplyr-distinct.R | 89 ++++++++++++++++++-------- 4 files changed, 90 insertions(+), 36 deletions(-) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 4c3b78e085c6e..4b54697d4bd90 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -62,7 +62,10 @@ supported_dplyr_methods <- list( relocate = NULL, compute = NULL, collapse = NULL, - distinct = "`.keep_all = TRUE` not supported", + distinct = c( + "`.keep_all = TRUE` returns a non-missing value if present,", + "only returning missing values if all are missing." + ), left_join = "the `copy` argument is ignored", right_join = "the `copy` argument is ignored", inner_join = "the `copy` argument is ignored", diff --git a/r/R/dplyr-distinct.R b/r/R/dplyr-distinct.R index 49948caa011e2..95fb837bd5d00 100644 --- a/r/R/dplyr-distinct.R +++ b/r/R/dplyr-distinct.R @@ -18,12 +18,6 @@ # The following S3 methods are registered on load if dplyr is present distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) { - if (.keep_all == TRUE) { - # TODO(ARROW-14045): the function is called "hash_one" (from ARROW-13993) - # May need to call it: `summarize(x = one(x), ...)` for x in non-group cols - arrow_not_supported("`distinct()` with `.keep_all = TRUE`") - } - original_gv <- dplyr::group_vars(.data) if (length(quos(...))) { # group_by() calls mutate() if there are any expressions in ... @@ -33,11 +27,28 @@ distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) { .data <- dplyr::group_by(.data, !!!syms(names(.data))) } - out <- dplyr::summarize(.data, .groups = "drop") + if (isTRUE(.keep_all)) { + # Note: in regular dplyr, `.keep_all = TRUE` returns the first row's value. + # However, Acero's `hash_one` function prefers returning non-null values. + # So, you'll get the same shape of data, but the values may differ. + keeps <- names(.data)[!(names(.data) %in% .data$group_by_vars)] + exprs <- lapply(keeps, function(x) call2("one", sym(x))) + names(exprs) <- keeps + } else { + exprs <- list() + } + + out <- dplyr::summarize(.data, !!!exprs, .groups = "drop") + # distinct() doesn't modify group by vars, so restore the original ones if (length(original_gv)) { out$group_by_vars <- original_gv } + if (isTRUE(.keep_all)) { + # Also ensure the column order matches the original + # summarize() will put the group_by_vars first + out <- dplyr::select(out, !!!syms(names(.data))) + } out } diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R index 340ebe7adc90f..275fca36542bf 100644 --- a/r/R/dplyr-funcs-agg.R +++ b/r/R/dplyr-funcs-agg.R @@ -150,6 +150,13 @@ register_bindings_aggregate <- function() { options = list(skip_nulls = na.rm, min_count = 0L) ) }) + register_binding("arrow::one", function(...) { + set_agg( + fun = "one", + data = ensure_one_arg(list2(...), "one"), + options = list() + ) + }) } set_agg <- function(...) { diff --git a/r/tests/testthat/test-dplyr-distinct.R b/r/tests/testthat/test-dplyr-distinct.R index 4c7f8894cd4e4..e4d789e8e9146 100644 --- a/r/tests/testthat/test-dplyr-distinct.R +++ b/r/tests/testthat/test-dplyr-distinct.R @@ -26,11 +26,8 @@ test_that("distinct()", { compare_dplyr_binding( .input %>% distinct(some_grouping, lgl) %>% - collect() %>% - # GH-14947: column output order changed in dplyr 1.1.0, so we need - # to make the column order explicit until dplyr 1.1.0 is on CRAN - select(some_grouping, lgl) %>% - arrange(some_grouping, lgl), + arrange(some_grouping, lgl) %>% + collect(), tbl ) }) @@ -60,11 +57,8 @@ test_that("distinct() can retain groups", { .input %>% group_by(some_grouping, int) %>% distinct(lgl) %>% - collect() %>% - # GH-14947: column output order changed in dplyr 1.1.0, so we need - # to make the column order explicit until dplyr 1.1.0 is on CRAN - select(some_grouping, int, lgl) %>% - arrange(lgl, int), + arrange(lgl, int) %>% + collect(), tbl ) @@ -73,11 +67,8 @@ test_that("distinct() can retain groups", { .input %>% group_by(y = some_grouping, int) %>% distinct(x = lgl) %>% - collect() %>% - # GH-14947: column output order changed in dplyr 1.1.0, so we need - # to make the column order explicit until dplyr 1.1.0 is on CRAN - select(y, int, x) %>% - arrange(int), + arrange(int) %>% + collect(), tbl ) }) @@ -95,11 +86,8 @@ test_that("distinct() can contain expressions", { .input %>% group_by(lgl, int) %>% distinct(x = some_grouping + 1) %>% - collect() %>% - # GH-14947: column output order changed in dplyr 1.1.0, so we need - # to make the column order explicit until dplyr 1.1.0 is on CRAN - select(lgl, int, x) %>% - arrange(int), + arrange(int) %>% + collect(), tbl ) }) @@ -115,12 +103,57 @@ test_that("across() works in distinct()", { }) test_that("distinct() can return all columns", { - skip("ARROW-14045") - compare_dplyr_binding( - .input %>% - distinct(lgl, .keep_all = TRUE) %>% - collect() %>% - arrange(int), - tbl - ) + # hash_one prefers to keep non-null values, which is different from .keep_all in dplyr + # so we can't compare the result directly + expected <- tbl %>% + # Drop factor because of #44661: + # NotImplemented: Function 'hash_one' has no kernel matching input types + # (dictionary, uint8) + select(-fct) %>% + distinct(lgl, .keep_all = TRUE) %>% + arrange(int) + + with_table <- tbl %>% + arrow_table() %>% + select(-fct) %>% + distinct(lgl, .keep_all = TRUE) %>% + arrange(int) %>% + collect() + + expect_identical(dim(with_table), dim(expected)) + expect_identical(names(with_table), names(expected)) + + # Test with some mutation in there + expected <- tbl %>% + select(-fct) %>% + distinct(lgl, bigger = int * 10L, .keep_all = TRUE) %>% + arrange(int) + + with_table <- tbl %>% + arrow_table() %>% + select(-fct) %>% + distinct(lgl, bigger = int * 10, .keep_all = TRUE) %>% + arrange(int) %>% + collect() + + expect_identical(dim(with_table), dim(expected)) + expect_identical(names(with_table), names(expected)) + expect_identical(with_table$bigger, expected$bigger) + + # Mutation that overwrites + expected <- tbl %>% + select(-fct) %>% + distinct(lgl, int = int * 10L, .keep_all = TRUE) %>% + arrange(int) + + with_table <- tbl %>% + arrow_table() %>% + select(-fct) %>% + distinct(lgl, int = int * 10, .keep_all = TRUE) %>% + arrange(int) %>% + collect() + + expect_identical(dim(with_table), dim(expected)) + expect_identical(names(with_table), names(expected)) + expect_identical(with_table$int, expected$int) })