From ff8b1465825548da9910eda097e69143304f62d1 Mon Sep 17 00:00:00 2001 From: merliseclyde Date: Mon, 9 Nov 2020 23:54:19 -0500 Subject: [PATCH] created function to extract the MPM related to issue #53 --- R/extract_models.R | 70 ++++++++++++++++++++++++++++ man/extract_MPM.Rd | 55 ++++++++++++++++++++++ tests/testthat/test-extract_models.R | 36 ++++++++++++++ 3 files changed, 161 insertions(+) create mode 100644 R/extract_models.R create mode 100644 man/extract_MPM.Rd create mode 100644 tests/testthat/test-extract_models.R diff --git a/R/extract_models.R b/R/extract_models.R new file mode 100644 index 00000000..f6757ca9 --- /dev/null +++ b/R/extract_models.R @@ -0,0 +1,70 @@ +#' Extract the Median Probability Model +#' @description Extracts the Median Probability Model from a bas object +#' @param object An object of class "bas" or "basglm" +#' @return a new object with of class "bas" or "basglm" with the Median +#' Probability Model +#' @details The Median Probability Model is the model where variables are +#' included if the marginal posterior probabilty of the coefficient being +#' zero is greater than 0.5. As this model may not have been sampled (and even +#' if it has) it is oftern faster to refit the model using bas, rather than +#' search the list of models to see where it was included. +#' @examples +#' data(Hald, package=BAS) +#' hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC") +#' extract_MPM(hald_bic) +#' +#' data(Pima.tr, package="MASS") +#' Pima_bas = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7, method="BAS", +#' betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0), family=binomial(), +#' modelprior=uniform()) +#' extract_MPM(Pima_bas) +#' @family bas methods +#' @export +extract_MPM = function(object) { +# if (!(class(object) %in% c("basglm", "bas"))) { +# stop("requires an object of class 'bas' or 'basglm'") } + nvar <- object$n.vars - 1 + bestmodel <- as.numeric(object$probne0 > .5) + + if (is.null(object$call$weights)) { + object$call$weights = NULL } + + if ( !("basglm" %in% class(object))) { + # call lm + newobject <- bas.lm( + eval(object$call$formula), + data = eval(object$call$data, parent.frame()), + weights = eval(object$call$weights), + n.models = 1, + alpha = object$g, + initprobs = object$probne0, + prior = object$prior, + modelprior = object$modelprior, + update = NULL, + bestmodel = bestmodel + ) + + } +else { + glm_family = eval(object$family, parent.frame())$family + family <- get(glm_family, mode = "function", envir = parent.frame()) + newobject <- bas.glm( + eval(object$call$formula), + data = eval(object$call$data, parent.frame()), + weights = eval(object$call$weights), + family = family, + n.models = 1L, + initprobs = object$probne0, + betaprior = object$betaprior, + modelprior = object$modelprior, + update = NULL, + bestmodel = bestmodel + ) +} + newobject$probne0 = object$probne0 + mf = object$call + mf$n.models = 1 + mf$bestmodel = bestmodel + newobject$call = mf + return(newobject) +} diff --git a/man/extract_MPM.Rd b/man/extract_MPM.Rd new file mode 100644 index 00000000..d64b59e3 --- /dev/null +++ b/man/extract_MPM.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_models.R +\name{extract_MPM} +\alias{extract_MPM} +\title{Extract the Median Probability Model} +\usage{ +extract_MPM(object) +} +\arguments{ +\item{object}{An object of class "bas" or "basglm"} +} +\value{ +a new object with of class "bas" or "basglm" with the Median +Probability Model +} +\description{ +Extracts the Median Probability Model from a bas object +} +\details{ +The Median Probability Model is the model where variables are +included if the marginal posterior probabilty of the coefficient being +zero is greater than 0.5. As this model may not have been sampled (and even +if it has) it is oftern faster to refit the model using bas, rather than +search the list of models to see where it was included. +} +\examples{ +data(Hald, package=BAS) +hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC") +extract_MPM(hald_bic) + +data(Pima.tr, package="MASS") +Pima_bas = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7, method="BAS", + betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0), family=binomial(), + modelprior=uniform()) +extract_MPM(Pima_bas) +} +\seealso{ +Other bas methods: +\code{\link{BAS}}, +\code{\link{bas.lm}()}, +\code{\link{coef.bas}()}, +\code{\link{confint.coef.bas}()}, +\code{\link{confint.pred.bas}()}, +\code{\link{diagnostics}()}, +\code{\link{fitted.bas}()}, +\code{\link{force.heredity.bas}()}, +\code{\link{image.bas}()}, +\code{\link{plot.confint.bas}()}, +\code{\link{predict.basglm}()}, +\code{\link{predict.bas}()}, +\code{\link{summary.bas}()}, +\code{\link{update.bas}()}, +\code{\link{variable.names.pred.bas}()} +} +\concept{bas methods} diff --git a/tests/testthat/test-extract_models.R b/tests/testthat/test-extract_models.R new file mode 100644 index 00000000..ffe6dc13 --- /dev/null +++ b/tests/testthat/test-extract_models.R @@ -0,0 +1,36 @@ +test_that("extract Median Probability Model", { + + data(Hald, package="BAS") + hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC", + modelprior = uniform()) + hald_MPM_manual = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC", + modelprior = uniform(), + n.models = 1L, + bestmodel = as.numeric(hald_bic$probne0 > .5) + ) + hald_MPM = extract_MPM(hald_bic) + expect_equal(hald_bic$n.vars, hald_MPM$n.vars) + expect_equal(as.numeric(hald_bic$probne0 > .5), + as.vector(which.matrix(hald_MPM$which[1], hald_MPM$n.vars))) + expect_equal(predict(hald_bic, estimator="MPM")$fit, + predict(hald_MPM)$fit, + check.attributes = FALSE) + + data(Pima.tr, package="MASS") + Pima_bas = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7, method="BAS", + betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0), + family=binomial(), + modelprior=uniform()) + Pima_MPM_man = bas.glm(type ~ ., data=Pima.tr, method="BAS", + betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0), + family=binomial(), + modelprior=uniform(), + n.models = 1L, + bestmodel = Pima_bas$probne0 > 0.5) + + + Pima_MPM = extract_MPM(Pima_bas) + expect_equal(as.numeric(Pima_bas$probne0 > .5), + as.vector(which.matrix(Pima_MPM$which[1], Pima_MPM$n.vars))) + expect_equal(coef(Pima_MPM)$coef, coef(Pima_MPM_man)$coef) +})