Skip to content

Commit

Permalink
Merge pull request #156 from quantgroup/group-feature-importance
Browse files Browse the repository at this point in the history
Joint feature importance
  • Loading branch information
christophM authored Jan 12, 2021
2 parents 9fdecc0 + b446afe commit 814cdb9
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 14 deletions.
66 changes: 53 additions & 13 deletions R/FeatureImp.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@
#' # Plot the results directly
#' plot(imp)
#'
#' # We can calculate feature importance for a subset of features
#' imp <- FeatureImp$new(mod, loss = "mae", features = c("crim", "zn", "indus"))
#' plot(imp)
#'
#'# We can calculate joint importance of groups of features
#'groups = list(
#' grp1 = c("crim", "zn", "indus", "chas"),
#' grp2 = c("nox", "rm", "age", "dis"),
#' grp3 = c("rad", "tax", "ptratio", "black", "lstat")
#')
#'imp <- FeatureImp$new(mod, loss = "mae", features = groups)
#'plot(imp)
#'
#' # FeatureImp also works with multiclass classification.
#' # In this case, the importance measurement regards all classes
Expand Down Expand Up @@ -122,6 +134,11 @@ FeatureImp <- R6::R6Class("FeatureImp",
#' How often should the shuffling of the feature be repeated?
#' The higher the number of repetitions the more stable and accurate the
#' results become.
#' @param features (`character or list`)\cr
#' For which features do you want importance scores calculated. The default
#' value of `NULL` implies all features. Use a named list of character vectors
#' to define groups of features for which joint importance will be calculated.
#' See examples.
#' @return (data.frame)\cr
#' data.frame with the results of the feature importance computation. One
#' row per feature with the following columns:
Expand All @@ -134,11 +151,13 @@ FeatureImp <- R6::R6Class("FeatureImp",
#' plots, the median importance over the repetitions as a point.
#'
initialize = function(predictor, loss, compare = "ratio",
n.repetitions = 5) {
n.repetitions = 5, features = NULL) {

assert_choice(compare, c("ratio", "difference"))
assert_number(n.repetitions)
self$compare <- compare


if (!inherits(loss, "function")) {
## Only allow metrics from Metrics package
allowedLosses <- c(
Expand Down Expand Up @@ -166,6 +185,18 @@ FeatureImp <- R6::R6Class("FeatureImp",
warning("Model error is 0, switching from compare='ratio' to compare='difference'")
self$compare <- "difference"
}

# process features argument
if (is.null(features)) {
features <- private$sampler$feature.names
}
if (!is.list(features)) {
features <- as.list(features)
names(features) <- unlist(features)
}
assert_subset(unique(unlist(features)), private$sampler$feature.names, empty.ok = FALSE)
self$features <- features

# suppressing package startup messages
suppressPackageStartupMessages(private$run(self$predictor$batch.size))
},
Expand All @@ -187,7 +218,13 @@ FeatureImp <- R6::R6Class("FeatureImp",
#' depending on whether the importance was calculated as difference
#' between original model error and model error after permutation or as
#' ratio.
compare = NULL
compare = NULL,

#' @field features (`list`)\cr Features for which importance scores are to
#' be calculated. The names are the feature/group names, while the contents
#' specify which feature(s) are to be permuted.
features = NULL

),

private = list(
Expand All @@ -204,20 +241,21 @@ FeatureImp <- R6::R6Class("FeatureImp",
private$dataSample <- private$getData()
result <- NULL

estimate_feature_imp <- function(feature,
estimate_feature_imp <- function(group,
features,
data.sample,
y,
n.repetitions,
y.names,
pred,
loss) {

cnames <- setdiff(colnames(data.sample), y.names)
qResults <- data.table::data.table()
y.vec <- data.table::data.table()
for (repi in 1:n.repetitions) {
mg <- MarginalGenerator$new(data.sample, data.sample,
features = feature, n.sample.dist = 1, y = y, cartesian = FALSE,
features = features, n.sample.dist = 1, y = y, cartesian = FALSE,
id.dist = TRUE
)
while (!mg$finished) {
Expand All @@ -231,7 +269,7 @@ FeatureImp <- R6::R6Class("FeatureImp",
}
# AGGREGATE measurements
results <- data.table::data.table(
feature = feature, actual = y.vec[[1]], predicted = qResults[[1]],
feature = group, actual = y.vec[[1]], predicted = qResults[[1]],
num_rep = rep(1:n.repetitions, each = nrow(data.sample))
)
results <- results[, list("permutation_error" = loss(actual, predicted)),
Expand All @@ -248,19 +286,21 @@ FeatureImp <- R6::R6Class("FeatureImp",
loss <- self$loss

result <- rbindlist(unname(
future.apply::future_lapply(private$sampler$feature.names, function(x) {
estimate_feature_imp(x,
future.apply::future_mapply(estimate_feature_imp,
group = names(self$features),
features = self$features,
MoreArgs = list(
data.sample = data.sample,
y = y,
n.repetitions = n.repetitions,
y.names = y.names,
pred = pred,
loss = loss
)
},
future.seed = TRUE,
future.globals = FALSE,
future.packages = loadedNamespaces()
),
SIMPLIFY = FALSE,
future.seed = TRUE,
future.globals = FALSE,
future.packages = loadedNamespaces()
)
), use.names = TRUE)

Expand Down
30 changes: 29 additions & 1 deletion man/FeatureImp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions tests/testthat/test-FeatureImp.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,49 @@ test_that("Feature Importance 0", {
fimp <- FeatureImp$new(pred, loss = "mae", n.repetitions = 3)
expect_equal(fimp$results$importance[3], 1)
})

test_that("FeatureImp works for a subset of features", {
var.imp <- FeatureImp$new(predictor1, loss = "mse", features = c("a", "b"))
dat <- var.imp$results
expect_class(dat, "data.frame")
expect_false("data.table" %in% class(dat))
expect_equal(colnames(dat), expected_colnames)
expect_equal(nrow(dat), 2)
p <- plot(var.imp)
expect_s3_class(p, c("gg", "ggplot"))
p
})

test_that("Invalid feature names are caught", {
expect_error(
FeatureImp$new(predictor1, loss = "mse", features = c("x", "y", "z")),
"failed: Must be a subset of {'a','b','c','d'}, but is {'x','y','z'}",
fixed = TRUE
)
})

test_that("FeatureImp works for groups of features", {
groups = list(ab = c("a", "b"), cd = c("c", "d"))
var.imp <- FeatureImp$new(predictor1, loss = "mse", features = groups)
dat <- var.imp$results
expect_class(dat, "data.frame")
expect_false("data.table" %in% class(dat))
expect_equal(colnames(dat), expected_colnames)
expect_equal(nrow(dat), 2)
p <- plot(var.imp)
expect_s3_class(p, c("gg", "ggplot"))
p
})

test_that("FeatureImp works for overlapping groups of features", {
groups = list(ab = c("a", "b"), bc = c("b", "c"))
var.imp <- FeatureImp$new(predictor1, loss = "mse", features = groups)
dat <- var.imp$results
expect_class(dat, "data.frame")
expect_false("data.table" %in% class(dat))
expect_equal(colnames(dat), expected_colnames)
expect_equal(nrow(dat), 2)
p <- plot(var.imp)
expect_s3_class(p, c("gg", "ggplot"))
p
})

0 comments on commit 814cdb9

Please sign in to comment.