GH-29642: [R] Support for .keep_all = TRUE with distinct() #44652

Merged 6 commits on Dec 7, 2024
5 changes: 4 additions & 1 deletion r/R/arrow-package.R
@@ -62,7 +62,10 @@ supported_dplyr_methods <- list(
   relocate = NULL,
   compute = NULL,
   collapse = NULL,
-  distinct = "`.keep_all = TRUE` not supported",
+  distinct = c(
+    "`.keep_all = TRUE` returns a non-missing value if present,",
+    "only returning missing values if all are missing."
+  ),
   left_join = "the `copy` argument is ignored",
   right_join = "the `copy` argument is ignored",
   inner_join = "the `copy` argument is ignored",
25 changes: 18 additions & 7 deletions r/R/dplyr-distinct.R
@@ -18,12 +18,6 @@
 # The following S3 methods are registered on load if dplyr is present

 distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
-  if (.keep_all == TRUE) {
-    # TODO(ARROW-14045): the function is called "hash_one" (from ARROW-13993)
-    # May need to call it: `summarize(x = one(x), ...)` for x in non-group cols
-    arrow_not_supported("`distinct()` with `.keep_all = TRUE`")
-  }
-
   original_gv <- dplyr::group_vars(.data)
   if (length(quos(...))) {
     # group_by() calls mutate() if there are any expressions in ...
@@ -33,11 +27,28 @@ distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) {
     .data <- dplyr::group_by(.data, !!!syms(names(.data)))
   }

-  out <- dplyr::summarize(.data, .groups = "drop")
+  if (isTRUE(.keep_all)) {
+    # Note: in regular dplyr, `.keep_all = TRUE` returns the first row's value.
+    # However, Acero's `hash_one` function prefers returning non-null values.
+    # So, you'll get the same shape of data, but the values may differ.
Comment on lines +31 to +33
Member

This behavior change is probably either not impactful or, if folks are relying on the first-row behavior, a sign of a bug in their code. It does seem like something we should mention, though (in the docs at least?).

Or maybe with a one-time warning?

Member Author

It is documented on the acero man page; that's the change to arrow-package.R. I'd rather not add a one-time warning: that's a slippery slope if we were going to be chatty about every subtle difference between how Acero works and how dplyr works on data.frames.
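
For readers following the thread, here is a minimal sketch of the difference being discussed. It assumes dplyr and an arrow build that includes this PR are installed; the data frame `df` and its columns `g` and `x` are hypothetical and not part of the PR.

```r
library(dplyr)
library(arrow)

# Hypothetical data: group "a" has a missing value in its first row
df <- data.frame(g = c("a", "a", "b"), x = c(NA, 1L, 2L))

# dplyr on a data.frame keeps the first row of each group,
# so x is NA for g == "a"
on_df <- df %>%
  distinct(g, .keep_all = TRUE)

# With this change, the same pipeline on an Arrow Table aggregates x with
# Acero's hash_one, which prefers a non-null value when the group has one
on_arrow <- arrow_table(df) %>%
  distinct(g, .keep_all = TRUE) %>%
  arrange(g) %>%
  collect()

# Same shape and column names; the value of x for g == "a" may differ
identical(dim(on_arrow), dim(on_df))
identical(names(on_arrow), names(on_df))
```

This mirrors what the new tests check: dimensions and names are compared, but the non-grouping values are not.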

+    keeps <- names(.data)[!(names(.data) %in% .data$group_by_vars)]
+    exprs <- lapply(keeps, function(x) call2("one", sym(x)))
+    names(exprs) <- keeps
+  } else {
+    exprs <- list()
+  }
+
+  out <- dplyr::summarize(.data, !!!exprs, .groups = "drop")

   # distinct() doesn't modify group by vars, so restore the original ones
   if (length(original_gv)) {
     out$group_by_vars <- original_gv
   }
+  if (isTRUE(.keep_all)) {
+    # Also ensure the column order matches the original
+    # summarize() will put the group_by_vars first
+    out <- dplyr::select(out, !!!syms(names(.data)))
+  }
   out
 }
7 changes: 7 additions & 0 deletions r/R/dplyr-funcs-agg.R
@@ -150,6 +150,13 @@ register_bindings_aggregate <- function() {
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
+  register_binding("arrow::one", function(...) {
+    set_agg(
+      fun = "one",
+      data = ensure_one_arg(list2(...), "one"),
+      options = list()
+    )
+  })
 }

 set_agg <- function(...) {
89 changes: 61 additions & 28 deletions r/tests/testthat/test-dplyr-distinct.R
@@ -26,11 +26,8 @@ test_that("distinct()", {
   compare_dplyr_binding(
     .input %>%
       distinct(some_grouping, lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(some_grouping, lgl) %>%
-      arrange(some_grouping, lgl),
+      arrange(some_grouping, lgl) %>%
+      collect(),
     tbl
   )
 })
@@ -60,11 +57,8 @@ test_that("distinct() can retain groups", {
     .input %>%
       group_by(some_grouping, int) %>%
       distinct(lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(some_grouping, int, lgl) %>%
-      arrange(lgl, int),
+      arrange(lgl, int) %>%
+      collect(),
     tbl
   )

@@ -73,11 +67,8 @@ test_that("distinct() can retain groups", {
     .input %>%
       group_by(y = some_grouping, int) %>%
       distinct(x = lgl) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(y, int, x) %>%
-      arrange(int),
+      arrange(int) %>%
+      collect(),
     tbl
   )
 })
@@ -95,11 +86,8 @@ test_that("distinct() can contain expressions", {
     .input %>%
       group_by(lgl, int) %>%
       distinct(x = some_grouping + 1) %>%
-      collect() %>%
-      # GH-14947: column output order changed in dplyr 1.1.0, so we need
-      # to make the column order explicit until dplyr 1.1.0 is on CRAN
-      select(lgl, int, x) %>%
-      arrange(int),
+      arrange(int) %>%
+      collect(),
     tbl
   )
 })
@@ -115,12 +103,57 @@ test_that("across() works in distinct()", {
 })

 test_that("distinct() can return all columns", {
-  skip("ARROW-14045")
-  compare_dplyr_binding(
-    .input %>%
-      distinct(lgl, .keep_all = TRUE) %>%
-      collect() %>%
-      arrange(int),
-    tbl
-  )
+  # hash_one prefers to keep non-null values, which is different from .keep_all in dplyr
+  # so we can't compare the result directly
+  expected <- tbl %>%
+    # Drop factor because of #44661:
+    # NotImplemented: Function 'hash_one' has no kernel matching input types
+    # (dictionary<values=string, indices=int8, ordered=0>, uint8)
Comment on lines +109 to +111
Member

Is 110-111 the error that someone would get if they tried distinct(..., .keep_all = TRUE) with a factor in the table/data.frame?

We might want to make that a bit nicer / more grokable for folks who might not have the dictionary -> factor knowledge top of mind

Member Author

Yeah that's the error message. I'd have to think about how/where best to catch that and translate that to R-speak. As it turns out, dictionary isn't the only unsupported type, it's just the only one we have in this test data frame. I think list types and other non-simple types are also not supported, IIRC from RTFS.
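
Until that error is translated, one possible workaround is to drop the unsupported columns before the data reaches Arrow, which is what `select(-fct)` does in the test. The sketch below assumes dplyr and an arrow build that includes this PR; the data frame `df` and the helper variable `no_fct` are hypothetical, and it only handles factors, not the other unsupported types mentioned above.

```r
library(dplyr)
library(arrow)

# Hypothetical data with a factor column, which Arrow stores as a dictionary type
df <- data.frame(
  lgl = c(TRUE, TRUE, FALSE),
  int = 1:3,
  fct = factor(c("a", "b", "a"))
)

# Drop factor columns in base R before handing the data to Arrow,
# mirroring select(-fct) in the test above
no_fct <- df[, !vapply(df, is.factor, logical(1)), drop = FALSE]

arrow_table(no_fct) %>%
  distinct(lgl, .keep_all = TRUE) %>%
  arrange(int) %>%
  collect()
```

Dropping the column loses its data, so this only fits cases where the factor is not needed downstream.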

+    select(-fct) %>%
+    distinct(lgl, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+
+  # Test with some mutation in there
+  expected <- tbl %>%
+    select(-fct) %>%
+    distinct(lgl, bigger = int * 10L, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, bigger = int * 10, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+  expect_identical(with_table$bigger, expected$bigger)
+
+  # Mutation that overwrites
+  expected <- tbl %>%
+    select(-fct) %>%
+    distinct(lgl, int = int * 10L, .keep_all = TRUE) %>%
+    arrange(int)
+
+  with_table <- tbl %>%
+    arrow_table() %>%
+    select(-fct) %>%
+    distinct(lgl, int = int * 10, .keep_all = TRUE) %>%
+    arrange(int) %>%
+    collect()
+
+  expect_identical(dim(with_table), dim(expected))
+  expect_identical(names(with_table), names(expected))
+  expect_identical(with_table$int, expected$int)
 })