diff --git a/R/ggplot.R b/R/ggplot.R index d9564a4ce..15b255551 100644 --- a/R/ggplot.R +++ b/R/ggplot.R @@ -210,8 +210,11 @@ gg.data.frame <- function(...) { #' @importFrom utils modifyList #' @param data A prediction object, usually the result of a [predict.bru()] call. #' @param mapping a set of aesthetic mappings created by `aes`. These are passed on to `geom_line`. -#' @param ribbon If TRUE, plot a ribbon around the line based on the upper and lower 2.5 percent quantiles. -#' @param alpha The ribbons numeric alpha level in `[0,1]`. +#' @param ribbon If TRUE, plot a ribbon around the line based on the smalles and largest +#' quantiles present in the data, found by matching names starting with `q` and +#' followed by a numerical value. `inla()`-style `numeric+"quant"` names are converted +#' to inlabru style before matching. +#' @param alpha The ribbons numeric alpha (transparency) level in `[0,1]`. #' @param bar If TRUE plot boxplot-style summary for each variable. #' @param \dots Arguments passed on to `geom_line`. #' @return Concatenation of a `geom_line` value and optionally a `geom_ribbon` value. @@ -220,14 +223,23 @@ gg.data.frame <- function(...) { gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar = FALSE, ...) { requireNamespace("ggplot2") - if (all(c("0.025quant", "0.975quant") %in% names(data))) { - names(data)[names(data) == "0.975quant"] <- "q0.975" - names(data)[names(data) == "0.025quant"] <- "q0.025" - names(data)[names(data) == "0.5quant"] <- "median" - } - lqname <- "q0.025" - uqname <- "q0.975" + # Convert from old and inla style names + new_quant_names <- list(q0.025 = "0.025quant", + q0.5 = "0.5quant", + q0.975 = "0.975quant") + for (quant_name in names(new_quant_names[new_quant_names %in% names(data)])) { + names(data)[names(data) == new_quant_names[[quant_name]]] <- quant_name + } + # Find quantile levels + quant_names <- names(data)[grepl("^q[01]\\.?[0-9]*$", names(data))] + if (length(quant_names) > 0) { + quant_probs <- as.numeric(sub("^q", "", quant_names)) + quant_names <- quant_names[order(quant_probs)] + quant_probs <- quant_probs[order(quant_probs)] + lqname <- quant_names[1] + uqname <- quant_names[length(quant_names)] + } if (bar | (nrow(data) == 1)) { sz <- 10 # bar width @@ -252,52 +264,21 @@ gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar color = .data$variable ), shape = 95, size = 0 - ), # Fake ylab - ggplot2::geom_segment( - data = data, - mapping = ggplot2::aes( - y = .data[[lqname]], - yend = .data[[uqname]], - x = .data$variable, - xend = .data$variable, - color = .data$variable - ), - size = sz ) - ) - - # Min and max sample - if (all(c("smax", "smax") %in% names(data))) { + ) + if (length(quant_names) > 0) { geom <- c( geom, ggplot2::geom_segment( data = data, mapping = ggplot2::aes( - y = .data$smin, - yend = .data$smax, + y = .data[[lqname]], + yend = .data[[uqname]], x = .data$variable, xend = .data$variable, color = .data$variable ), - size = 1 - ), - ggplot2::geom_point( - data = data, - mapping = ggplot2::aes( - x = .data$variable, - y = .data$smax, - color = .data$variable - ), - shape = 95, size = 5 - ), - ggplot2::geom_point( - data = data, - mapping = ggplot2::aes( - x = .data$variable, - y = .data$smin, - color = .data$variable - ), - shape = 95, size = 5 + size = sz ) ) } @@ -312,27 +293,35 @@ gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar y = .data$mean ), color = "black", shape = 95, size = sz - ), - ggplot2::geom_point( + ) + ) + if ("q0.5" %in% quant_names) { + geom <- c( + geom, + ggplot2::geom_point( data = data, mapping = ggplot2::aes( x = .data$variable, - y = .data$median + y = .data$q0.5 ), color = "black", shape = 20, size = med_sz - ), - ggplot2::coord_flip() - ) + ) + ) + } + geom <- c(geom, + ggplot2::coord_flip()) } else { if ("pdf" %in% names(data)) { y.str <- "pdf" ribbon <- FALSE } else if ("mean" %in% names(data)) { y.str <- "mean" + } else if ("q0.5" %in% names(data)) { + y.str <- "q0.5" } else if ("median" %in% names(data)) { - y.str <- "median" + y.str <- "q0.5" } else { - stop("Prediction has neither mean nor median or pdf as column. Don't know what to plot.") + stop("Prediction has neither mean nor median/q0.5 or pdf as column. Don't know what to plot.") } line.map <- ggplot2::aes( @@ -340,23 +329,23 @@ gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar y = .data[[y.str]] ) - ribbon.map <- ggplot2::aes( - x = .data[[names(data)[1]]], - ymin = .data[[lqname]], - ymax = .data[[uqname]] - ) - if (!is.null(mapping)) { line.map <- utils::modifyList(line.map, mapping) } - # Use line color for ribbon filling - if ("colour" %in% names(line.map)) { - ribbon.map <- modifyList(ribbon.map, ggplot2::aes(fill = .data[[line.map[["colour"]]]])) - } - geom <- ggplot2::geom_line(data = data, line.map, ...) - if (ribbon) { + + if (ribbon && length(quant_names) > 0) { + ribbon.map <- ggplot2::aes( + x = .data[[names(data)[1]]], + ymin = .data[[lqname]], + ymax = .data[[uqname]] + ) + # Use line color for ribbon filling + if ("colour" %in% names(line.map)) { + ribbon.map <- modifyList(ribbon.map, + ggplot2::aes(fill = .data[[line.map[["colour"]]]])) + } geom <- c(geom, ggplot2::geom_ribbon(data = data, ribbon.map, alpha = alpha)) } } diff --git a/man/gg.prediction.Rd b/man/gg.prediction.Rd index 34807caa7..20f2ee1e2 100644 --- a/man/gg.prediction.Rd +++ b/man/gg.prediction.Rd @@ -11,9 +11,12 @@ \item{mapping}{a set of aesthetic mappings created by \code{aes}. These are passed on to \code{geom_line}.} -\item{ribbon}{If TRUE, plot a ribbon around the line based on the upper and lower 2.5 percent quantiles.} +\item{ribbon}{If TRUE, plot a ribbon around the line based on the smalles and largest +quantiles present in the data, found by matching names starting with \code{q} and +followed by a numerical value. \code{inla()}-style \code{numeric+"quant"} names are converted +to inlabru style before matching.} -\item{alpha}{The ribbons numeric alpha level in \verb{[0,1]}.} +\item{alpha}{The ribbons numeric alpha (transparency) level in \verb{[0,1]}.} \item{bar}{If TRUE plot boxplot-style summary for each variable.} diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R index d26051477..31046638f 100644 --- a/tests/testthat/test_predict.R +++ b/tests/testthat/test_predict.R @@ -57,9 +57,7 @@ test_that("bru: factor component", { ) # The statistics include mean, standard deviation, the 2.5% quantile, the median, - # the 97.5% quantile, minimum and maximum sample drawn from the posterior as well as - # the coefficient of variation and the variance. - + # the 97.5% quantile expect_equal(is.data.frame(xpost), TRUE) expect_equal(nrow(xpost), 1) @@ -71,7 +69,7 @@ test_that("bru: factor component", { # The predict function can also be used to simultaneously estimate posteriors # of multiple variables: - xipost <- generate(fit, + xipost <- predict(fit, data = NULL, formula = ~ c( Intercept = Intercept_latent,