Skip to content

Commit

Permalink
created function to extract the MPM related to issue #53
Browse files Browse the repository at this point in the history
  • Loading branch information
merliseclyde committed Nov 10, 2020
1 parent 656f3e9 commit ff8b146
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
70 changes: 70 additions & 0 deletions R/extract_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#' Extract the Median Probability Model
#' @description Extracts the Median Probability Model from a bas object
#' @param object An object of class "bas" or "basglm"
#' @return a new object with of class "bas" or "basglm" with the Median
#' Probability Model
#' @details The Median Probability Model is the model where variables are
#' included if the marginal posterior probabilty of the coefficient being
#' zero is greater than 0.5. As this model may not have been sampled (and even
#' if it has) it is oftern faster to refit the model using bas, rather than
#' search the list of models to see where it was included.
#' @examples
#' data(Hald, package=BAS)
#' hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC")
#' extract_MPM(hald_bic)
#'
#' data(Pima.tr, package="MASS")
#' Pima_bas = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7, method="BAS",
#' betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0), family=binomial(),
#' modelprior=uniform())
#' extract_MPM(Pima_bas)
#' @family bas methods
#' @export
extract_MPM = function(object) {
# if (!(class(object) %in% c("basglm", "bas"))) {
# stop("requires an object of class 'bas' or 'basglm'") }
nvar <- object$n.vars - 1
bestmodel <- as.numeric(object$probne0 > .5)

if (is.null(object$call$weights)) {
object$call$weights = NULL }

if ( !("basglm" %in% class(object))) {
# call lm
newobject <- bas.lm(
eval(object$call$formula),
data = eval(object$call$data, parent.frame()),
weights = eval(object$call$weights),
n.models = 1,
alpha = object$g,
initprobs = object$probne0,
prior = object$prior,
modelprior = object$modelprior,
update = NULL,
bestmodel = bestmodel
)

}
else {
glm_family = eval(object$family, parent.frame())$family
family <- get(glm_family, mode = "function", envir = parent.frame())
newobject <- bas.glm(
eval(object$call$formula),
data = eval(object$call$data, parent.frame()),
weights = eval(object$call$weights),
family = family,
n.models = 1L,
initprobs = object$probne0,
betaprior = object$betaprior,
modelprior = object$modelprior,
update = NULL,
bestmodel = bestmodel
)
}
newobject$probne0 = object$probne0
mf = object$call
mf$n.models = 1
mf$bestmodel = bestmodel
newobject$call = mf
return(newobject)
}
55 changes: 55 additions & 0 deletions man/extract_MPM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions tests/testthat/test-extract_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_that("extract Median Probability Model", {

data(Hald, package="BAS")
hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC",
modelprior = uniform())
hald_MPM_manual = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC",
modelprior = uniform(),
n.models = 1L,
bestmodel = as.numeric(hald_bic$probne0 > .5)
)
hald_MPM = extract_MPM(hald_bic)
expect_equal(hald_bic$n.vars, hald_MPM$n.vars)
expect_equal(as.numeric(hald_bic$probne0 > .5),
as.vector(which.matrix(hald_MPM$which[1], hald_MPM$n.vars)))
expect_equal(predict(hald_bic, estimator="MPM")$fit,
predict(hald_MPM)$fit,
check.attributes = FALSE)

data(Pima.tr, package="MASS")
Pima_bas = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7, method="BAS",
betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0),
family=binomial(),
modelprior=uniform())
Pima_MPM_man = bas.glm(type ~ ., data=Pima.tr, method="BAS",
betaprior=CCH(a=1, b=nrow(Pima.tr)/2, s=0),
family=binomial(),
modelprior=uniform(),
n.models = 1L,
bestmodel = Pima_bas$probne0 > 0.5)


Pima_MPM = extract_MPM(Pima_bas)
expect_equal(as.numeric(Pima_bas$probne0 > .5),
as.vector(which.matrix(Pima_MPM$which[1], Pima_MPM$n.vars)))
expect_equal(coef(Pima_MPM)$coef, coef(Pima_MPM_man)$coef)
})

0 comments on commit ff8b146

Please sign in to comment.