Skip to content

Commit

Permalink
feat: Add across()
Browse files Browse the repository at this point in the history
Closes #37
  • Loading branch information
nathaneastwood committed Oct 23, 2020
1 parent 29d08e1 commit f890578
Show file tree
Hide file tree
Showing 12 changed files with 358 additions and 16 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: poorman
Type: Package
Title: A Poor Man's Dependency Free Recreation of 'dplyr'
Version: 0.2.2.3
Version: 0.2.2.4
Authors@R: person("Nathan", "Eastwood", "", "[email protected]",
role = c("aut", "cre"))
Maintainer: Nathan Eastwood <[email protected]>
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ S3method(summarize,grouped_data)
S3method(transmute,default)
S3method(transmute,grouped_data)
export("%>%")
export(across)
export(add_count)
export(add_tally)
export(all_of)
Expand All @@ -50,6 +51,11 @@ export(coalesce)
export(contains)
export(count)
export(cume_dist)
export(cur_data)
export(cur_data_all)
export(cur_group)
export(cur_group_id)
export(cur_group_rows)
export(dense_rank)
export(desc)
export(distinct)
Expand Down
110 changes: 110 additions & 0 deletions R/across.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#' Apply a function (or functions) across multiple columns
#'
#' @description
#' `across()` makes it easy to apply the same transformation to multiple columns, allowing you to use [select()]
#' semantics inside "data-masking" functions like [summarise()] and [mutate()].
#'
#' `across()` supersedes the family of {dplyr} "scoped variants" like `summarise_at()`, `summarise_if()`, and
#' `summarise_all()` and therefore these functions will not be implemented in {poorman}.
#'
#' @param cols,.cols <[`poor-select`][select_helpers]> Columns to transform. Because `across()` is used within functions
#' like `summarise()` and `mutate()`, you can't select or compute upon grouping variables.
#' @param .fns Functions to apply to each of the selected columns.
#' Possible values are:
#'
#' - `NULL`, to return the columns untransformed.
#' - A function, e.g. `mean`.
#' - A list of functions, e.g. `list(mean = mean, sum = sum)`
#'
#' Within these functions you can use [cur_column()] and [cur_group()] to access the current column and grouping keys
#' respectively.
#' @param ... Additional arguments for the function calls in `.fns`.
#' @param .names `character(n)`. Currently limited to specifying a vector of names to use for the outputs.
#'
#' @return
#' A `data.frame` with one column for each column in `.cols` and each function in `.fns`.
#'
#' @examples
#' # across() -----------------------------------------------------------------
#' iris %>%
#' group_by(Species) %>%
#' summarise(across(starts_with("Sepal"), mean))
#' iris %>%
#' mutate(across(where(is.factor), as.character))
#'
#' # Additional parameters can be passed to functions
#' iris %>%
#' group_by(Species) %>%
#' summarise(across(starts_with("Sepal"), mean, na.rm = TRUE))
#'
#' # A named list of functions
#' iris %>%
#' group_by(Species) %>%
#' summarise(across(starts_with("Sepal"), list(mean = mean, sd = sd)))
#'
#' # Use the .names argument to control the output names
#' iris %>%
#' group_by(Species) %>%
#' summarise(across(starts_with("Sepal"), mean, .names = c("mean_sepal_length", "mean_sepal_width")))
#'
#' @export
across <- function(.cols = everything(), .fns = NULL, ..., .names = NULL) {
  # `.cols` is captured unevaluated; `setup_across()` resolves it against the
  # data currently held in `context` and works out functions and output names.
  setup <- setup_across(substitute(.cols), .fns, .names)
  cols <- setup$cols
  n_cols <- length(cols)
  # Nothing was selected, so there are no columns to contribute.
  if (n_cols == 0L) return(data.frame())

  funs <- setup$funs
  n_fns <- length(funs)
  # Kept under a different name so the base function `names()` is not shadowed.
  out_names <- setup$names

  # One output per selected column when `.fns` is NULL, otherwise one output
  # per (column, function) pair.
  n_out <- if (is.null(funs)) n_cols else n_fns * n_cols

  # Fail early with a readable message instead of silently producing `NA`
  # column names (too few names) or a cryptic `names<-` error (too many).
  if (!is.null(out_names) && length(out_names) != n_out) {
    stop(
      sprintf(
        "`.names` must have length %d (one name per output column), not %d.",
        n_out, length(out_names)
      ),
      call. = FALSE
    )
  }

  data <- context$get_columns(cols)

  # No functions supplied: return the selected columns, optionally renamed.
  if (is.null(funs)) {
    data <- data.frame(data)
    if (!is.null(out_names)) names(data) <- out_names
    return(data)
  }

  # Results are stored column-major: every function for the first column,
  # then every function for the second column, and so on. This is the same
  # order in which `setup_across()` generates the default names.
  res <- vector(mode = "list", length = n_out)
  k <- 1L
  for (i in seq_len(n_cols)) {
    col <- data[[i]]
    for (j in seq_len(n_fns)) {
      res[[k]] <- funs[[j]](col, ...)
      k <- k + 1L
    }
  }
  names(res) <- out_names
  as.data.frame(res)
}

# -- helpers -----------------------------------------------------------------------------------------------------------

setup_across <- function(.cols, .fns, .names) {
  # Resolve the (unevaluated) selection to positions, then to column names.
  selected <- eval_select_pos(.data = context$.data, .cols, .group_pos = FALSE)
  cols <- context$get_colnames()[selected]
  # Grouping variables are dropped from the selection when the data is grouped.
  if (context$is_grouped()) {
    cols <- setdiff(cols, get_groups(context$.data))
  }

  # Normalise `.fns` to either NULL or a list of functions.
  funs <- .fns
  if (!is.null(.fns) && !is.list(.fns)) funs <- list(.fns)

  # `supplied_nms` is a snapshot of the names as the caller gave them. It is
  # intentionally NOT refreshed after the positional names are filled in just
  # below, because the single-unnamed-function case further down relies on it
  # still being NULL.
  supplied_nms <- names(funs)
  if (is.null(supplied_nms) && !is.null(.fns)) {
    names(funs) <- seq_along(funs)
  }

  # Partially named lists: give the unnamed entries their position as a name.
  blank <- nchar(supplied_nms) == 0L
  if (any(blank)) {
    miss <- which(blank)
    names(funs)[miss] <- miss
    supplied_nms <- names(funs)
  }

  if (!is.null(.names)) {
    # Caller-supplied output names win.
    out_nms <- .names
  } else if (length(funs) == 1L && is.null(supplied_nms)) {
    # A single unnamed function keeps the original column names.
    out_nms <- cols
  } else {
    # "<col>_<fn>" for every pair, with the function varying fastest.
    grid <- expand.grid(names(funs), cols)
    out_nms <- do.call(paste, c(rev(grid), sep = "_"))
    if (length(out_nms) == 0L) out_nms <- NULL
  }

  list(cols = cols, funs = funs, names = out_nms)
}
7 changes: 7 additions & 0 deletions R/context.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ context <- new.env()
# Data
# Store `.data` in the shared `context` environment.
context$setup <- function(.data) context$.data <- .data
# Return the data currently stored in `context`.
context$get_data <- function() context$.data
# Return the stored data subset to `cols`; `drop = FALSE` is passed so the result is not simplified.
context$get_columns <- function(cols) context$.data[, cols, drop = FALSE]
# Number of rows of the stored data.
context$get_nrow <- function() nrow(context$.data)
# Column names of the stored data.
context$get_colnames <- function() colnames(context$.data)
# Whether the stored data has groups, as reported by `has_groups()`.
context$is_grouped <- function() has_groups(context$.data)
# Remove `.data` from `context`.
context$clean <- function() rm(list = c(".data"), envir = context)

#' Context dependent expressions
Expand Down Expand Up @@ -55,6 +57,7 @@ n <- function() {
#' @description
#' * `cur_data()` gives the current data for the current group (excluding grouping variables).
#' @rdname context
#' @export
cur_data <- function() {
check_group_context("`cur_data()`")
data <- context$get_data()
Expand All @@ -64,6 +67,7 @@ cur_data <- function() {
#' @description
#' * `cur_data_all()` gives the current data for the current group (including grouping variables).
#' @rdname context
#' @export
cur_data_all <- function() {
check_group_context("`cur_data_all()`")
ungroup(context$get_data())
Expand All @@ -73,6 +77,7 @@ cur_data_all <- function() {
#' * `cur_group()` gives the group keys, a single row `data.frame` containing a column for each grouping variable and
#' its value.
#' @rdname context
#' @export
cur_group <- function() {
check_group_context("`cur_group()`")
data <- context$get_data()
Expand All @@ -84,6 +89,7 @@ cur_group <- function() {
#' @description
#' * `cur_group_id()` gives a unique numeric identifier for the current group.
#' @rdname context
#' @export
cur_group_id <- function() {
check_group_context("`cur_group_id()`")
data <- context$get_data()
Expand All @@ -97,6 +103,7 @@ cur_group_id <- function() {
#' @description
#' * `cur_group_rows()` gives the rows the groups appear in the data.
#' @rdname context
#' @export
cur_group_rows <- function() {
check_group_context("`cur_group_rows()`")
data <- context$get_data()
Expand Down
2 changes: 1 addition & 1 deletion R/group_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ group_split <- function(.data, ..., .keep = TRUE) {
group_keys <- function(.data) {
groups <- get_groups(.data)
context$setup(.data)
res <- context$.data[, context$get_colnames() %in% groups, drop = FALSE]
res <- context$get_columns(context$get_colnames() %in% groups)
res <- res[!duplicated(res), , drop = FALSE]
if (nrow(res) == 0L) return(res)
class(res) <- "data.frame"
Expand Down
6 changes: 4 additions & 2 deletions R/mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ mutate <- function(.data, ...) {
#' @export
mutate.default <- function(.data, ...) {
conditions <- dotdotdot(..., .impute_names = TRUE)
.data[, setdiff(names(conditions), names(.data))] <- NA
context$setup(.data)
on.exit(context$clean(), add = TRUE)
for (i in seq_along(conditions)) {
context$.data[, names(conditions)[i]] <- eval(conditions[[i]], envir = context$.data)
res <- eval(conditions[[i]], envir = context$.data)
if (!is.list(res)) res <- list(res)
if (is.null(names(res))) names(res) <- names(conditions)[[i]]
context$.data[, names(res)] <- res
}
context$.data
}
Expand Down
2 changes: 1 addition & 1 deletion R/select_positions.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ select_env$get_ncol <- function() ncol(select_env$.data)
#' A cleaner interface to evaluating select_positions when column names are not passed via ...
#' @noRd
eval_select_pos <- function(.data, .cols, .group_pos = FALSE) {
do.call(select_positions, list(.data = .data, .group_pos = .group_pos, .cols))
do.call(select_positions, list(.data = .data, .cols, .group_pos = .group_pos))
}
23 changes: 12 additions & 11 deletions R/summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,21 @@ summarise.default <- function(.data, ...) {
fns <- dotdotdot(...)
context$setup(.data)
on.exit(context$clean(), add = TRUE)
groups_exist <- has_groups(context$.data)
groups_exist <- context$is_grouped()
if (groups_exist) {
group <- unique(context$.data[, get_groups(context$.data), drop = FALSE])
group <- unique(context$get_columns(get_groups(context$.data)))
}
res <- lapply(
fns,
function(x) {
x_res <- do.call(with, list(context$.data, x))
if (is.list(x_res)) I(x_res) else x_res
res <- vector(mode = "list", length = length(fns))
for (i in seq_along(fns)) {
out <- do.call(with, list(context$.data, fns[[i]]))
nms <- if (!is_named(out)) {
if (!is.null(names(fns)[[i]])) names(fns)[[i]] else deparse(fns[[i]])
} else {
NULL
}
)
res <- as.data.frame(res)
fn_names <- names(fns)
colnames(res) <- if (is.null(fn_names)) fns else fn_names
res[[i]] <- build_data_frame(out, nms)
}
res <- do.call(cbind, res)
if (groups_exist) res <- cbind(group, res, row.names = NULL)
res
}
Expand Down
14 changes: 14 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,17 @@ is_named <- function(x) {
names_are_invalid <- function(x) {
x == "" | is.na(x)
}

#' Coerce an atomic vector, a plain list or a `data.frame` into a `data.frame`
#'
#' Atomic vectors become a single column, plain lists become a single list-column (wrapped in `I()`), and
#' `data.frame`s are passed through unchanged. Any other input yields `NULL`. When `nms` is given it replaces the
#' column names of the result.
#' @noRd
build_data_frame <- function(x, nms = NULL) {
  out <- NULL
  if (is.atomic(x)) {
    out <- data.frame(x)
  } else if (is.data.frame(x)) {
    out <- x
  } else if (is.list(x)) {
    # `I()` keeps the list as one list-column instead of spreading it out.
    out <- data.frame(I(x))
  }
  if (is.null(nms)) return(out)
  colnames(out) <- nms
  out
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ reference:

- title: Vector functions
contents:
- across
- between
- case_when
- coalesce
Expand Down
Loading

0 comments on commit f890578

Please sign in to comment.