Skip to content

Commit

Permalink
GH-29642: [R] Support for .keep_all = TRUE with distinct() (#44652)
Browse files Browse the repository at this point in the history
### Rationale for this change

Support a missing feature, `.keep_all = TRUE` in `distinct()`. This mostly
wires Acero's existing `hash_one` aggregation up to R, then adds docs and tests.

This is mostly picking up where #13934 started and finishing it out.
Thanks @mopcup for the initial lift.

### What changes are included in this PR?

An aggregation binding for `one`, some symbol manipulation in `distinct()`,
and tests. I also cleaned up some dplyr test shims from 2022 (the GH-14947
column-order workarounds).

### Are these changes tested?

Yes, though if anyone knows of odd corners in `distinct()` that aren't
covered by this, we can add more tests.

### Are there any user-facing changes?

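Yes indeed: `distinct()` on Arrow queries now accepts `.keep_all = TRUE`. Unlike dplyr, which keeps the first row's values, it returns a non-missing value if one is present, only returning missing values if all are missing.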
* GitHub Issue: #29642
  • Loading branch information
nealrichardson authored Dec 7, 2024
1 parent cf7ab12 commit 1b3caf6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 36 deletions.
5 changes: 4 additions & 1 deletion r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ supported_dplyr_methods <- list(
relocate = NULL,
compute = NULL,
collapse = NULL,
distinct = "`.keep_all = TRUE` not supported",
distinct = c(
"`.keep_all = TRUE` returns a non-missing value if present,",
"only returning missing values if all are missing."
),
left_join = "the `copy` argument is ignored",
right_join = "the `copy` argument is ignored",
inner_join = "the `copy` argument is ignored",
Expand Down
25 changes: 18 additions & 7 deletions r/R/dplyr-distinct.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
# The following S3 methods are registered on load if dplyr is present

distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
if (.keep_all == TRUE) {
# TODO(ARROW-14045): the function is called "hash_one" (from ARROW-13993)
# May need to call it: `summarize(x = one(x), ...)` for x in non-group cols
arrow_not_supported("`distinct()` with `.keep_all = TRUE`")
}

original_gv <- dplyr::group_vars(.data)
if (length(quos(...))) {
# group_by() calls mutate() if there are any expressions in ...
Expand All @@ -33,11 +27,28 @@ distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
.data <- dplyr::group_by(.data, !!!syms(names(.data)))
}

out <- dplyr::summarize(.data, .groups = "drop")
if (isTRUE(.keep_all)) {
# Note: in regular dplyr, `.keep_all = TRUE` returns the first row's value.
# However, Acero's `hash_one` function prefers returning non-null values.
# So, you'll get the same shape of data, but the values may differ.
keeps <- names(.data)[!(names(.data) %in% .data$group_by_vars)]
exprs <- lapply(keeps, function(x) call2("one", sym(x)))
names(exprs) <- keeps
} else {
exprs <- list()
}

out <- dplyr::summarize(.data, !!!exprs, .groups = "drop")

# distinct() doesn't modify group by vars, so restore the original ones
if (length(original_gv)) {
out$group_by_vars <- original_gv
}
if (isTRUE(.keep_all)) {
# Also ensure the column order matches the original
# summarize() will put the group_by_vars first
out <- dplyr::select(out, !!!syms(names(.data)))
}
out
}

Expand Down
7 changes: 7 additions & 0 deletions r/R/dplyr-funcs-agg.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ register_bindings_aggregate <- function() {
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
# Aggregation binding for `one`: hands set_agg() an aggregation whose function
# name is "one", with an empty options list.
# The arguments collected from `...` go through ensure_one_arg() together with
# the label "one"; presumably that helper enforces a single argument -- confirm
# against its definition.
# NOTE(review): this appears to be the binding that the `one(<column>)` calls
# built by distinct(.keep_all = TRUE) rely on -- confirm how the name resolves.
register_binding("arrow::one", function(...) {
set_agg(
fun = "one",
data = ensure_one_arg(list2(...), "one"),
options = list()
)
})
}

set_agg <- function(...) {
Expand Down
89 changes: 61 additions & 28 deletions r/tests/testthat/test-dplyr-distinct.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@ test_that("distinct()", {
compare_dplyr_binding(
.input %>%
distinct(some_grouping, lgl) %>%
collect() %>%
# GH-14947: column output order changed in dplyr 1.1.0, so we need
# to make the column order explicit until dplyr 1.1.0 is on CRAN
select(some_grouping, lgl) %>%
arrange(some_grouping, lgl),
arrange(some_grouping, lgl) %>%
collect(),
tbl
)
})
Expand Down Expand Up @@ -60,11 +57,8 @@ test_that("distinct() can retain groups", {
.input %>%
group_by(some_grouping, int) %>%
distinct(lgl) %>%
collect() %>%
# GH-14947: column output order changed in dplyr 1.1.0, so we need
# to make the column order explicit until dplyr 1.1.0 is on CRAN
select(some_grouping, int, lgl) %>%
arrange(lgl, int),
arrange(lgl, int) %>%
collect(),
tbl
)

Expand All @@ -73,11 +67,8 @@ test_that("distinct() can retain groups", {
.input %>%
group_by(y = some_grouping, int) %>%
distinct(x = lgl) %>%
collect() %>%
# GH-14947: column output order changed in dplyr 1.1.0, so we need
# to make the column order explicit until dplyr 1.1.0 is on CRAN
select(y, int, x) %>%
arrange(int),
arrange(int) %>%
collect(),
tbl
)
})
Expand All @@ -95,11 +86,8 @@ test_that("distinct() can contain expressions", {
.input %>%
group_by(lgl, int) %>%
distinct(x = some_grouping + 1) %>%
collect() %>%
# GH-14947: column output order changed in dplyr 1.1.0, so we need
# to make the column order explicit until dplyr 1.1.0 is on CRAN
select(lgl, int, x) %>%
arrange(int),
arrange(int) %>%
collect(),
tbl
)
})
Expand All @@ -115,12 +103,57 @@ test_that("across() works in distinct()", {
})

test_that("distinct() can return all columns", {
skip("ARROW-14045")
compare_dplyr_binding(
.input %>%
distinct(lgl, .keep_all = TRUE) %>%
collect() %>%
arrange(int),
tbl
)
# hash_one prefers to keep non-null values, which is different from .keep_all in dplyr
# so we can't compare the result directly
expected <- tbl %>%
# Drop factor because of #44661:
# NotImplemented: Function 'hash_one' has no kernel matching input types
# (dictionary<values=string, indices=int8, ordered=0>, uint8)
select(-fct) %>%
distinct(lgl, .keep_all = TRUE) %>%
arrange(int)

with_table <- tbl %>%
arrow_table() %>%
select(-fct) %>%
distinct(lgl, .keep_all = TRUE) %>%
arrange(int) %>%
collect()

expect_identical(dim(with_table), dim(expected))
expect_identical(names(with_table), names(expected))

# Test with some mutation in there
expected <- tbl %>%
select(-fct) %>%
distinct(lgl, bigger = int * 10L, .keep_all = TRUE) %>%
arrange(int)

with_table <- tbl %>%
arrow_table() %>%
select(-fct) %>%
distinct(lgl, bigger = int * 10, .keep_all = TRUE) %>%
arrange(int) %>%
collect()

expect_identical(dim(with_table), dim(expected))
expect_identical(names(with_table), names(expected))
expect_identical(with_table$bigger, expected$bigger)

# Mutation that overwrites
expected <- tbl %>%
select(-fct) %>%
distinct(lgl, int = int * 10L, .keep_all = TRUE) %>%
arrange(int)

with_table <- tbl %>%
arrow_table() %>%
select(-fct) %>%
distinct(lgl, int = int * 10, .keep_all = TRUE) %>%
arrange(int) %>%
collect()

expect_identical(dim(with_table), dim(expected))
expect_identical(names(with_table), names(expected))
expect_identical(with_table$int, expected$int)
})

0 comments on commit 1b3caf6

Please sign in to comment.