vip package: Variable interactions #158
Hi @HanLum, thanks for posting the issue here! We took `vint()` out of vip, but here is its source so you can still define and use it yourself:

```r
#' Interaction effects
#'
#' Quantify the strength of two-way interaction effects using a simple
#' \emph{feature importance ranking measure} (FIRM) approach. For details, see
#' \href{https://arxiv.org/abs/1805.04755}{Greenwell et al. (2018)}.
#'
#' @param object A fitted model object (e.g., a \code{"randomForest"} object).
#'
#' @param feature_names Character vector giving the names of two or more
#' features of interest. Interaction strength is computed for every pair.
#'
#' @param progress Character string giving the name of the progress bar to use
#' (see \code{\link[plyr]{create_progress_bar}}). Default is \code{"none"}.
#'
#' @param parallel Logical indicating whether or not to compute the pairwise
#' statistics in parallel using a backend provided by the \code{foreach}
#' package. Default is \code{FALSE}.
#'
#' @param paropts List containing additional options to be passed on to
#' \code{\link[foreach]{foreach}} when \code{parallel = TRUE}.
#'
#' @param ... Additional optional arguments to be passed on to
#' \code{\link[pdp]{partial}}.
#'
#' @return A tibble with columns \code{Variables} and \code{Interaction},
#' sorted by decreasing interaction strength.
#'
#' @details This function quantifies the strength of interaction between
#' features \eqn{X_1} and \eqn{X_2} using their two-way partial dependence on
#' the target \eqn{Y}. For each value of \eqn{X_1}, it takes the standard
#' deviation of the partial dependence over \eqn{X_2}; the standard deviation
#' of these values measures how much the importance of \eqn{X_2} changes with
#' \eqn{X_1}. The same is done with the roles of \eqn{X_1} and \eqn{X_2}
#' reversed, and the two results are averaged. See
#' \href{https://arxiv.org/abs/1805.04755}{Greenwell et al. (2018)} for
#' details and examples.
#'
#' @references
#' Greenwell, B. M., Boehmke, B. C., and McCarthy, A. J.: A Simple
#' and Effective Model-Based Variable Importance Measure. arXiv preprint
#' arXiv:1805.04755 (2018).
#'
#' @examples
#' \dontrun{
#' #
#' # The Friedman 1 benchmark problem
#' #
#'
#' # Load required packages
#' library(gbm)
#' library(ggplot2)
#'
#' # Simulate training data
#' trn <- vip::gen_friedman(500, seed = 101)  # ?vip::gen_friedman
#'
#' #
#' # NOTE: The only interaction that actually occurs in the model from which
#' # these data are generated is between x1 and x2!
#' #
#'
#' # Fit a GBM to the training data
#' set.seed(102)  # for reproducibility
#' fit <- gbm(y ~ ., data = trn, distribution = "gaussian", n.trees = 1000,
#'            interaction.depth = 2, shrinkage = 0.01, bag.fraction = 0.8,
#'            cv.folds = 5)
#' best_iter <- gbm.perf(fit, plot.it = FALSE, method = "cv")
#'
#' # Quantify relative interaction strength
#' all_pairs <- combn(paste0("x", 1:10), m = 2)
#' res <- NULL
#' for (i in seq_along(all_pairs[1, ])) {
#'   interact <- vint(fit, feature_names = all_pairs[, i], n.trees = best_iter)
#'   res <- rbind(res, interact)
#' }
#'
#' # Sort pairs by interaction strength, then plot the top 20
#' res <- res[order(res$Interaction, decreasing = TRUE), ]
#' top_20 <- res[1L:20L, ]
#' ggplot(top_20, aes(x = reorder(Variables, Interaction), y = Interaction)) +
#'   geom_col() +
#'   coord_flip() +
#'   xlab("") +
#'   ylab("Interaction strength")
#' }
vint <- function(object, feature_names, progress = "none", parallel = FALSE,
                 paropts = NULL, ...) {
  # warning("This function is experimental, use at your own risk!", call. = FALSE)
  # FIXME: Should we force `chull = FALSE` in the call to `pdp::partial()`?
  # Every pair of features, one pair per column
  all.pairs <- utils::combn(feature_names, m = 2)
  ints <- plyr::aaply(
    all.pairs, .margins = 2, .progress = progress, .parallel = parallel,
    .paropts = paropts,
    .fun = function(x) {
      # Two-way partial dependence for the current pair
      pd <- pdp::partial(object, pred.var = x, ...)
      # Average the two "sd of conditional sd" statistics
      mean(c(
        stats::sd(tapply(pd$yhat, INDEX = pd[[x[1L]]], FUN = stats::sd)),
        stats::sd(tapply(pd$yhat, INDEX = pd[[x[2L]]], FUN = stats::sd))
      ))
    })
  ints <- data.frame(
    "Variables" = paste0(all.pairs[1L, ], "*", all.pairs[2L, ]),
    "Interaction" = ints
  )
  # Strongest interactions first
  ints <- ints[order(ints[["Interaction"]], decreasing = TRUE), ]
  tibble::as_tibble(ints)
}
```
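To see what the statistic inside `vint()` measures without fitting a model, here is a minimal base R sketch. It assumes a partial dependence grid is already available as a data frame with one column per feature plus `yhat` (the shape `pdp::partial()` returns) and applies the same `tapply()`/`sd()` computation to two hand-built surfaces. `firm_interaction()` and the toy surfaces are hypothetical illustrations, not part of vip or pdp.

```r
# Hypothetical helper: the statistic vint() computes, applied to a
# precomputed partial dependence grid `pd` (feature columns plus `yhat`).
firm_interaction <- function(pd, x) {
  # sd of yhat over x[2] at each fixed value of x[1], then sd across x[1]
  s1 <- stats::sd(tapply(pd$yhat, INDEX = pd[[x[1L]]], FUN = stats::sd))
  # The same with the roles of the two features swapped
  s2 <- stats::sd(tapply(pd$yhat, INDEX = pd[[x[2L]]], FUN = stats::sd))
  mean(c(s1, s2))
}

# Toy 11 x 11 grid standing in for a real partial dependence grid
grid <- expand.grid(x1 = seq(0, 1, length.out = 11),
                    x2 = seq(0, 1, length.out = 11))

# Additive surface: the effect of x2 does not depend on x1
additive <- transform(grid, yhat = 2 * x1 + 3 * x2)
# Product surface: the effect of x2 grows with x1
product <- transform(grid, yhat = 10 * x1 * x2)

firm_interaction(additive, c("x1", "x2"))  # essentially zero
firm_interaction(product, c("x1", "x2"))   # clearly positive
```

On the additive surface every slice has the same spread, so the statistic is zero up to floating-point rounding; this is why near-zero values on the order of 1e-16 can show up for pairs that do not interact.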
Quick example:

```r
# Requires vint() (defined above) plus the pdp, plyr, and tibble packages

# Load required packages
library(gbm)

# Simulate training data
trn <- vip::gen_friedman(500, seed = 101)  # ?vip::gen_friedman

#
# NOTE: The only interaction that actually occurs in the model from which
# these data are generated is between x1 and x2!
#

# Fit a GBM to the training data
set.seed(102)  # for reproducibility
fit <- gbm(y ~ ., data = trn, distribution = "gaussian", n.trees = 1000,
           interaction.depth = 2, shrinkage = 0.01, bag.fraction = 0.8,
           cv.folds = 5)
best_iter <- gbm.perf(fit, plot.it = FALSE, method = "cv")

# Quantify relative interaction strength
all_pairs <- combn(paste0("x", 1:10), m = 2)
res <- NULL
for (i in seq_along(all_pairs[1, ])) {
  interact <- vint(fit, feature_names = all_pairs[, i], n.trees = best_iter)
  res <- rbind(res, interact)
}
print(res)
# # A tibble: 45 × 2
#    Variables Interaction
#    <chr>           <dbl>
#  1 x1*x2        4.98e- 1
#  2 x1*x3        3.60e- 3
#  3 x1*x4        8.61e- 2
#  4 x1*x5        7.69e- 2
#  5 x1*x6        0
#  6 x1*x7        0
#  7 x1*x8        0
#  8 x1*x9        0
#  9 x1*x10       1.06e-16
# 10 x2*x3        6.94e- 3
# # ℹ 35 more rows
# # ℹ Use `print(n = ...)` to see more rows
```
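Since the goal here is an interaction heatmap, one way to get there from `vint()`-style output is to reshape the pairwise results into a symmetric feature-by-feature matrix. This is a sketch assuming labels of the form `"x1*x2"`, as produced by `vint()`; `interaction_matrix()` is a hypothetical helper and is not part of vip or vivid.

```r
# Hypothetical helper: reshape vint()-style output (columns `Variables`,
# e.g. "x1*x2", and `Interaction`) into a symmetric matrix.
interaction_matrix <- function(res) {
  # Split labels such as "x1*x2" into the two feature names
  pairs <- strsplit(as.character(res$Variables), "*", fixed = TRUE)
  a <- vapply(pairs, `[`, character(1), 1L)
  b <- vapply(pairs, `[`, character(1), 2L)
  vars <- unique(c(a, b))
  # Fill both triangles so the matrix is symmetric; the diagonal stays NA
  m <- matrix(NA_real_, nrow = length(vars), ncol = length(vars),
              dimnames = list(vars, vars))
  m[cbind(a, b)] <- res$Interaction
  m[cbind(b, a)] <- res$Interaction
  m
}

# A few rows taken from the printed output above
example_res <- data.frame(
  Variables = c("x1*x2", "x1*x3", "x2*x3"),
  Interaction = c(0.498, 0.0036, 0.00694)
)
interaction_matrix(example_res)
```

The resulting matrix can be drawn with any heatmap function, for example base R's `image()`, which leaves the `NA` diagonal blank.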
Perfect, thank you so much @bgreenwell! 😊
I'm trying to use vip to plot variable interaction heatmaps with the vivid package (EIX doesn't work with my model). vivid relies on the `vint()` function, which no longer exists in vip. Is there an alternative way to calculate variable interactions for this purpose? Thank you, @bgreenwell!