diff --git a/NEWS.md b/NEWS.md index 0a0b6aed..ed2dbb5a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,7 @@ - improved interface for model initialization / optimisation parameters, which are now passed on to jackknife / bootstrap post-treatments - better support of GPU when using torch backend +* Change behavior of `predict()` function for PLNfit model to (i) return fitted values if newdata is missing or (ii) perform one VE step to improve fit if responses are provided (fix issue #114) # PLNmodels 1.0.4 (2023-08-24) diff --git a/R/PLNfit-S3methods.R b/R/PLNfit-S3methods.R index b0015562..8fd95053 100644 --- a/R/PLNfit-S3methods.R +++ b/R/PLNfit-S3methods.R @@ -13,13 +13,15 @@ isPLNfit <- function(Robject) {inherits(Robject, "PLNfit" )} #' #' @param object an R6 object with class [`PLNfit`] #' @param newdata A data frame in which to look for variables and offsets with which to predict +#' @param responses Optional data frame containing the count of the observed variables (matching the names of the provided as data in the PLN function), assuming the interest in in testing the model. #' @param type The type of prediction required. The default is on the scale of the linear predictors (i.e. log average count) +#' @param level Optional integer value the level to be used in obtaining the predictions. Level zero corresponds to the population predictions (default if `responses` is not provided) while level one (default) corresponds to predictions after evaluating the variational parameters for the new data. #' @param ... additional parameters for S3 compatibility. Not used #' @return A matrix of predicted log-counts (if `type = "link"`) or predicted counts (if `type = "response"`). #' @export -predict.PLNfit <- function(object, newdata, type = c("link", "response"), ...) { +predict.PLNfit <- function(object, newdata, responses = NULL, level = 1, type = c("link", "response"), ...) { stopifnot(isPLNfit(object)) - object$predict(newdata, type, parent.frame()) + object$predict(newdata = newdata, type = type, envir = parent.frame(), level = level, responses = responses) } #' Predict counts conditionally diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 98c5fa54..dc2c6939 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -506,23 +506,62 @@ PLNfit <- R6Class( #' @description Predict position, scores or observations of new data. #' @param newdata A data frame in which to look for variables with which to predict. If omitted, the fitted values are used. + #' @param responses Optional data frame containing the count of the observed variables (matching the names of the provided as data in the PLN function), assuming the interest in in testing the model. #' @param type Scale used for the prediction. Either `link` (default, predicted positions in the latent space) or `response` (predicted counts). + #' @param level Optional integer value the level to be used in obtaining the predictions. Level zero corresponds to the population predictions (default if `responses` is not provided) while level one (default) corresponds to predictions after evaluating the variational parameters for the new data. #' @param envir Environment in which the prediction is evaluated + #' + #' @details + #' Note that `level = 1` can only be used if responses are provided, + #' as the variational parameters can't be estimated otherwise. In the absence of responses, `level` is ignored and the fitted values are returned #' @return A matrix with predictions scores or counts. - predict = function(newdata, type = c("link", "response"), envir = parent.frame()) { + predict = function(newdata, responses = NULL, type = c("link", "response"), level = 1, envir = parent.frame()) { + + ## Ignore everything if newdata is not provided + if (missing(newdata)) { + return(self$fitted) + } + + n_new <- nrow(newdata) + ## Set level to 0 (to bypass VE step) if responses are not provided + if (is.null(responses)) { + level <- 0 + } ## Extract the model matrices from the new data set with initial formula X <- model.matrix(formula(private$formula)[-2], newdata, xlev = attr(private$formula, "xlevels")) O <- model.offset(model.frame(formula(private$formula)[-2], newdata)) + if (is.null(O)) O <- matrix(0, n_new, self$p) - ## mean latent positions in the parameter space - EZ <- X %*% private$B - if (!is.null(O)) EZ <- EZ + O - EZ <- sweep(EZ, 2, .5 * diag(self$model_par$Sigma), "+") + ## mean latent positions in the parameter space (covariates/offset only) + EZ <- X %*% private$B + O + rownames(EZ) <- rownames(newdata) colnames(EZ) <- colnames(private$Sigma) + ## Optimize M and S if responses are provided, + if (level == 1) { + VE <- self$optimize_vestep( + covariates = X, + offsets = O, + responses = as.matrix(responses), + weights = rep(1, n_new), + B = private$B, + Omega = private$Omega + ) + M <- VE$M + S <- VE$S + } else { + # otherwise set M = 0 and S = diag(Sigma) + M <- matrix(1, nrow = n_new, ncol = self$p) + S <- matrix(diag(private$Sigma), nrow = n_new, ncol = self$p, byrow = TRUE) + } + type <- match.arg(type) - results <- switch(type, link = EZ, response = exp(EZ)) + results <- switch( + type, + link = EZ + M, + response = exp(EZ + M + 0.5 * S) + ) attr(results, "type") <- type results }, diff --git a/man/PLNfit.Rd b/man/PLNfit.Rd index 2703ac82..303c21cb 100644 --- a/man/PLNfit.Rd +++ b/man/PLNfit.Rd @@ -285,7 +285,13 @@ The list of parameters \code{config} controls the post-treatment processing, wit \subsection{Method \code{predict()}}{ Predict position, scores or observations of new data. \subsection{Usage}{ -\if{html}{\out{