Skip to content

Commit

Permalink
Merge pull request #22 from bedapub/feat/multitrace
Browse files Browse the repository at this point in the history
read-only BatchContainer and trace as tibble
  • Loading branch information
idavydov authored Sep 1, 2023
2 parents 1a559a0 + d5cc052 commit 37fdcb8
Show file tree
Hide file tree
Showing 78 changed files with 1,275 additions and 1,415 deletions.
9 changes: 3 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,13 @@ Suggests:
tidyverse,
printr,
devtools (>= 2.0.0),
gridpattern,
ggpattern,
cowplot,
bestNormalize
bestNormalize,
here
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
VignetteBuilder: knitr
biocViews:
Remotes:
github::trevorld/gridpattern,
github::coolbutuseless/ggpattern
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ export(BatchContainer)
export(BatchContainerDimension)
export(L1_norm)
export(L2s_norm)
export(OptimizationTrace)
export(accept_leftmost_improvement)
export(as_label)
export(as_name)
Expand Down
45 changes: 45 additions & 0 deletions R/all_equal_df.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#' Compare two data.frames.
#'
#' This will convert factors to characters and disregard
#' row and column order
#'
#' @param df1 first [data.frame()] to compare
#' @param df2 second `data.frame()` to compare
#' @return `TRUE` or `FALSE` in case differences are present
#' @keywords internal
all_equal_df <- function(df1, df2) {
if (!is.data.frame(df1) || !is.data.frame(df2)) {
return(FALSE)
}

if (nrow(df1) != nrow(df2) || ncol(df1) != ncol(df2)) {
return(FALSE)
}

assertthat::assert_that(
!any(duplicated(colnames(df1))),
!any(duplicated(colnames(df2))),
msg = "duplicated colnames"
)

df2 <- df2[colnames(df1)]

# convert factors to characters
df1 <- df1 |>
dplyr::mutate(dplyr::across(dplyr::where(is.factor), as.character))
df2 <- df2 |>
dplyr::mutate(dplyr::across(dplyr::where(is.factor), as.character))

# order by all columns
df1 <- df1[do.call(order, df1),]
df2 <- df2[do.call(order, df2),]

# remove row names
rownames(df1) <- NULL
rownames(df2) <- NULL

assertthat::are_equal(
all.equal(df1, df2, check.attributes = FALSE),
TRUE
)
}
26 changes: 13 additions & 13 deletions R/assignment.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
#' @param samples data.frame with samples.
#' @param batch_container Instance of BatchContainer class
#'
#' @return Returns `BatchContainer`, invisibly.
#' @return Returns a new `BatchContainer`.
#' @example man/examples/assignment.R
assign_random <- function(batch_container, samples = NULL) {
assign_in_order(batch_container, samples)
batch_container <- assign_in_order(batch_container, samples)

batch_container$move_samples(
location_assignment = sample(batch_container$assignment)
)

invisible(batch_container)
batch_container
}

#' Distributes samples in order.
Expand All @@ -25,9 +25,10 @@ assign_random <- function(batch_container, samples = NULL) {
#' @param samples data.frame with samples.
#' @param batch_container Instance of BatchContainer class
#'
#' @return Returns `BatchContainer`, invisibly.
#' @return Returns a new `BatchContainer`.
#' @example man/examples/assignment.R
assign_in_order <- function(batch_container, samples = NULL) {
batch_container <- batch_container$copy()
if (is.null(samples)) {
assertthat::assert_that(batch_container$has_samples,
msg = "batch-container is empty and no samples provided"
Expand All @@ -46,7 +47,7 @@ assign_in_order <- function(batch_container, samples = NULL) {
rep(NA_integer_, n_locations - n_samples)
))

invisible(batch_container)
batch_container
}

#' Shuffling proposal function with constraints.
Expand Down Expand Up @@ -113,7 +114,7 @@ shuffle_with_constraints <- function(src = TRUE, dst = TRUE) {
#' the function will check if samples in `batch_container` are identical to the ones in the
#' `samples` argument.
#'
#' @return Returns `BatchContainer`, invisibly.
#' @return Returns a new `BatchContainer`.
#'
#' @examples
#' bc <- BatchContainer$new(
Expand All @@ -133,11 +134,12 @@ shuffle_with_constraints <- function(src = TRUE, dst = TRUE) {
#' 2, "a", 3, 5, "TRT",
#' )
#' # assign samples from the sample sheet
#' assign_from_table(bc, sample_sheet)
#' bc <- assign_from_table(bc, sample_sheet)
#'
#' bc$get_samples(remove_empty_locations = TRUE)
#'
assign_from_table <- function(batch_container, samples) {
batch_container <- batch_container$copy()
# sample sheet has all the batch variable
assertthat::assert_that(is.data.frame(samples) && nrow(samples) > 0,
msg = "samples should be non-empty data.frame"
Expand All @@ -156,11 +158,9 @@ assign_from_table <- function(batch_container, samples) {
if (is.null(batch_container$samples)) {
batch_container$samples <- only_samples
} else {
assertthat::assert_that(dplyr::all_equal(only_samples,
batch_container$get_samples(assignment = FALSE),
ignore_col_order = TRUE,
ignore_row_order = TRUE,
convert = TRUE
assertthat::assert_that(all_equal_df(
only_samples,
batch_container$get_samples(assignment = FALSE)
),
msg = "sample sheet should be compatible with samples inside the batch container"
)
Expand All @@ -177,5 +177,5 @@ assign_from_table <- function(batch_container, samples) {

batch_container$move_samples(location_assignment = samples_with_id$.sample_id)

invisible(batch_container)
batch_container
}
183 changes: 153 additions & 30 deletions R/batch_container.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,23 +328,67 @@ BatchContainer <- R6::R6Class("BatchContainer",

#' @description
#' Score current sample assignment,
#' @return Returns a vector of all scoring functions values.
score = function() {
assertthat::assert_that(!is.null(private$scoring_funcs),
msg = "Scoring function needs to be assigned"
#' @param scoring a function or a names list of scoring functions.
#' Each function should return a numeric vector.
#' @return Returns a named vector of all scoring functions values.
score = function(scoring) {
assertthat::assert_that(
!missing(scoring),
!is.null(scoring),
msg = "Scoring function needs to be provided"
)
assertthat::assert_that(is.list(private$scoring_funcs),
length(private$scoring_funcs) >= 1,
msg = "Scroring function should be a non-empty list"
if (is.function(scoring)) {
scoring <- list(scoring)
} else {
assertthat::assert_that(is.list(scoring), length(scoring) >= 1)
assertthat::assert_that(
all(purrr::map_lgl(scoring, is.function)),
msg = "All elements of scoring should be functions"
)
}
if (is.null(names(scoring))) {
names(scoring) <- stringr::str_c("score_", seq_along(scoring))
}
assertthat::assert_that(
!any(names(scoring) == ""),
msg = "scoring cannot be a partially named list"
)
assertthat::assert_that(!is.null(private$samples_table),
assertthat::assert_that(is.list(scoring),
length(scoring) >= 1,
msg = "Scoring function should be a non-empty list"
)
assertthat::assert_that(!is.null(names(scoring)),
msg = "scoring should be a named list"
)
assertthat::assert_that(self$has_samples,
msg = "No samples in the batch container, cannot compute score"
)

res <- purrr::map_dbl(private$scoring_funcs, ~ .x(self))
assertthat::assert_that(length(res) == length(private$scoring_funcs))

assertthat::assert_that(is.numeric(res), msg = "Scoring function should return a number")
res <- purrr::imap(
scoring,
\(f, i) {
v <- f(self)
assertthat::assert_that(
is.numeric(v),
length(v) >= 1,
msg = "scoring function should return a numeric vector of positive length"
)
if (length(v) > 1) {
if (is.null(names(v))) {
names(v) <- seq_along(v)
}
names(v) <- stringr::str_c(i, names(v))
} else {
names(v) <- i
}
v
}
) |>
purrr::flatten_dbl()
assertthat::assert_that(length(res) >= length(scoring))
assertthat::assert_that(
!any(names(res) == "step"),
msg = "score name cannot be 'step'"
)

return(res)
},
Expand All @@ -368,7 +412,7 @@ BatchContainer <- R6::R6Class("BatchContainer",
bc$samples_attr <- private$samples_attributes
}

bc$scoring_f <- self$scoring_f
bc$trace <- self$trace
bc
},

Expand Down Expand Up @@ -398,6 +442,100 @@ BatchContainer <- R6::R6Class("BatchContainer",
cat()
cat("\n")
invisible(self)
},

#' @field trace Optimization trace, a [tibble::tibble()]
trace = tibble::tibble(
optimization_index = numeric(),
call = list(),
start_assignment_vec = list(),
end_assignment_vec = list(),
scores = list(),
aggregated_scores = list(),
seed = list(),
elapsed = as.difftime(character(0))
),

#' @description
#' Return a table with scores from an optimization.
#'
#' @param index optimization index, all by default
#' @param include_aggregated include aggregated scores
#' @return a [tibble::tibble()] with scores
scores_table = function(index = NULL, include_aggregated = FALSE) {
assertthat::assert_that(
tibble::is_tibble(self$trace),
nrow(self$trace) >= 1,
msg = "trace should be available"
)
assertthat::assert_that(assertthat::is.flag(include_aggregated))
if (is.null(index)) {
index <- self$trace$optimization_index
}
assertthat::assert_that(
rlang::is_integerish(index),
msg = "index should be an integer"
)
d <- self$trace %>%
dplyr::filter(.data$optimization_index %in% index) %>%
dplyr::select(.data$optimization_index, .data$scores) %>%
tidyr::unnest(.data$scores) %>%
tidyr::pivot_longer(c(-.data$optimization_index, -.data$step),
names_to = "score",
values_to = "value") %>%
dplyr::mutate(aggregated = FALSE)
if (include_aggregated) {
d_agg <- self$trace %>%
dplyr::filter(.data$optimization_index %in% index) %>%
dplyr::select(.data$optimization_index, .data$aggregated_scores) %>%
tidyr::unnest(.data$aggregated_scores)

if ("step" %in% colnames(d_agg)) {
# if no aggregated scores are provided (aggregated_scores=NULL),
# there will be no step column after unnesting
d_agg <- d_agg %>%
tidyr::pivot_longer(c(-.data$optimization_index, -.data$step),
names_to = "score",
values_to = "value") %>%
dplyr::mutate(
aggregated = TRUE,
score = paste0("agg.", .data$score)
)
d <- dplyr::bind_rows(
d,
d_agg
)
}
}
d
},

#' @description
#' Plot trace
#' @param index optimization index, all by default
#' @param include_aggregated include aggregated scores
#' @param ... not used.
#' @return a [ggplot2::ggplot()] object
plot_trace = function(index = NULL, include_aggregated = FALSE, ...) {
d <- self$scores_table(index, include_aggregated) %>%
dplyr::mutate(
agg_title = dplyr::if_else(.data$aggregated, "aggregated", "score")
)
p <- ggplot2::ggplot(d) +
ggplot2::aes(.data$step, .data$value, group = .data$score, color = .data$score) +
ggplot2::geom_line() +
ggplot2::geom_point()
if (length(unique(d$optimization_index)) > 1) {
p <- p +
ggplot2::facet_wrap(~ optimization_index, scales = "free")
} else if (include_aggregated && any(d$aggregated)) {
p <- p +
ggplot2::facet_wrap(~ agg_title, scales = "free_y", ncol = 1)
} else {
p <- p +
ggplot2::facet_wrap(~ score, scales = "free_y", ncol = 1)
}
p
}
),
private = list(
Expand Down Expand Up @@ -445,22 +583,7 @@ BatchContainer <- R6::R6Class("BatchContainer",
#' Upon assignment a single function will be automatically converted to a list
#' In the later case each function is called.
scoring_f = function(value) {
if (missing(value)) {
private$scoring_funcs
} else {
if (is.null(value)) {
private$scoring_funcs <- NULL
} else if (is.function(value)) {
private$scoring_funcs <- list(value)
} else {
assertthat::assert_that(is.list(value), length(value) >= 1)
assertthat::assert_that(
all(purrr::map_lgl(self$scoring_f, is.function)),
msg = "All elements of scoring_f should be functions"
)
private$scoring_funcs <- value
}
}
stop("scoring_f is deprecated, pass it to optimize_design() directly instead")
},

#' @field has_samples
Expand Down
Loading

0 comments on commit 37fdcb8

Please sign in to comment.