Skip to content

Commit

Permalink
Auto-detect quantile levels in gg.prediction. Partial fix for #127
Browse files Browse the repository at this point in the history
  • Loading branch information
finnlindgren committed May 10, 2022
1 parent 60c243c commit d8cd14b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 71 deletions.
119 changes: 54 additions & 65 deletions R/ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,11 @@ gg.data.frame <- function(...) {
#' @importFrom utils modifyList
#' @param data A prediction object, usually the result of a [predict.bru()] call.
#' @param mapping a set of aesthetic mappings created by `aes`. These are passed on to `geom_line`.
#' @param ribbon If TRUE, plot a ribbon around the line based on the upper and lower 2.5 percent quantiles.
#' @param alpha The ribbons numeric alpha level in `[0,1]`.
#' @param ribbon If TRUE, plot a ribbon around the line based on the smalles and largest
#' quantiles present in the data, found by matching names starting with `q` and
#' followed by a numerical value. `inla()`-style `numeric+"quant"` names are converted
#' to inlabru style before matching.
#' @param alpha The ribbons numeric alpha (transparency) level in `[0,1]`.
#' @param bar If TRUE plot boxplot-style summary for each variable.
#' @param \dots Arguments passed on to `geom_line`.
#' @return Concatenation of a `geom_line` value and optionally a `geom_ribbon` value.
Expand All @@ -220,14 +223,23 @@ gg.data.frame <- function(...) {

gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar = FALSE, ...) {
requireNamespace("ggplot2")
if (all(c("0.025quant", "0.975quant") %in% names(data))) {
names(data)[names(data) == "0.975quant"] <- "q0.975"
names(data)[names(data) == "0.025quant"] <- "q0.025"
names(data)[names(data) == "0.5quant"] <- "median"
}
lqname <- "q0.025"
uqname <- "q0.975"
# Convert from old and inla style names
new_quant_names <- list(q0.025 = "0.025quant",
q0.5 = "0.5quant",
q0.975 = "0.975quant")

for (quant_name in names(new_quant_names[new_quant_names %in% names(data)])) {
names(data)[names(data) == new_quant_names[[quant_name]]] <- quant_name
}
# Find quantile levels
quant_names <- names(data)[grepl("^q[01]\\.?[0-9]*$", names(data))]
if (length(quant_names) > 0) {
quant_probs <- as.numeric(sub("^q", "", quant_names))
quant_names <- quant_names[order(quant_probs)]
quant_probs <- quant_probs[order(quant_probs)]
lqname <- quant_names[1]
uqname <- quant_names[length(quant_names)]
}

if (bar | (nrow(data) == 1)) {
sz <- 10 # bar width
Expand All @@ -252,52 +264,21 @@ gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar
color = .data$variable
),
shape = 95, size = 0
), # Fake ylab
ggplot2::geom_segment(
data = data,
mapping = ggplot2::aes(
y = .data[[lqname]],
yend = .data[[uqname]],
x = .data$variable,
xend = .data$variable,
color = .data$variable
),
size = sz
)
)

# Min and max sample
if (all(c("smax", "smax") %in% names(data))) {
)
if (length(quant_names) > 0) {
geom <- c(
geom,
ggplot2::geom_segment(
data = data,
mapping = ggplot2::aes(
y = .data$smin,
yend = .data$smax,
y = .data[[lqname]],
yend = .data[[uqname]],
x = .data$variable,
xend = .data$variable,
color = .data$variable
),
size = 1
),
ggplot2::geom_point(
data = data,
mapping = ggplot2::aes(
x = .data$variable,
y = .data$smax,
color = .data$variable
),
shape = 95, size = 5
),
ggplot2::geom_point(
data = data,
mapping = ggplot2::aes(
x = .data$variable,
y = .data$smin,
color = .data$variable
),
shape = 95, size = 5
size = sz
)
)
}
Expand All @@ -312,51 +293,59 @@ gg.prediction <- function(data, mapping = NULL, ribbon = TRUE, alpha = 0.3, bar
y = .data$mean
),
color = "black", shape = 95, size = sz
),
ggplot2::geom_point(
)
)
if ("q0.5" %in% quant_names) {
geom <- c(
geom,
ggplot2::geom_point(
data = data,
mapping = ggplot2::aes(
x = .data$variable,
y = .data$median
y = .data$q0.5
),
color = "black", shape = 20, size = med_sz
),
ggplot2::coord_flip()
)
)
)
}
geom <- c(geom,
ggplot2::coord_flip())
} else {
if ("pdf" %in% names(data)) {
y.str <- "pdf"
ribbon <- FALSE
} else if ("mean" %in% names(data)) {
y.str <- "mean"
} else if ("q0.5" %in% names(data)) {
y.str <- "q0.5"
} else if ("median" %in% names(data)) {
y.str <- "median"
y.str <- "q0.5"
} else {
stop("Prediction has neither mean nor median or pdf as column. Don't know what to plot.")
stop("Prediction has neither mean nor median/q0.5 or pdf as column. Don't know what to plot.")
}

line.map <- ggplot2::aes(
x = .data[[names(data)[1]]],
y = .data[[y.str]]
)

ribbon.map <- ggplot2::aes(
x = .data[[names(data)[1]]],
ymin = .data[[lqname]],
ymax = .data[[uqname]]
)

if (!is.null(mapping)) {
line.map <- utils::modifyList(line.map, mapping)
}

# Use line color for ribbon filling
if ("colour" %in% names(line.map)) {
ribbon.map <- modifyList(ribbon.map, ggplot2::aes(fill = .data[[line.map[["colour"]]]]))
}

geom <- ggplot2::geom_line(data = data, line.map, ...)
if (ribbon) {

if (ribbon && length(quant_names) > 0) {
ribbon.map <- ggplot2::aes(
x = .data[[names(data)[1]]],
ymin = .data[[lqname]],
ymax = .data[[uqname]]
)
# Use line color for ribbon filling
if ("colour" %in% names(line.map)) {
ribbon.map <- modifyList(ribbon.map,
ggplot2::aes(fill = .data[[line.map[["colour"]]]]))
}
geom <- c(geom, ggplot2::geom_ribbon(data = data, ribbon.map, alpha = alpha))
}
}
Expand Down
7 changes: 5 additions & 2 deletions man/gg.prediction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions tests/testthat/test_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ test_that("bru: factor component", {
)

# The statistics include mean, standard deviation, the 2.5% quantile, the median,
# the 97.5% quantile, minimum and maximum sample drawn from the posterior as well as
# the coefficient of variation and the variance.

# the 97.5% quantile
expect_equal(is.data.frame(xpost), TRUE)
expect_equal(nrow(xpost), 1)

Expand All @@ -71,7 +69,7 @@ test_that("bru: factor component", {
# The predict function can also be used to simultaneously estimate posteriors
# of multiple variables:

xipost <- generate(fit,
xipost <- predict(fit,
data = NULL,
formula = ~ c(
Intercept = Intercept_latent,
Expand Down

0 comments on commit d8cd14b

Please sign in to comment.